diff --git a/chirps/scan/tasks.py b/chirps/scan/tasks.py index fbcffd7a..96b66ec1 100644 --- a/chirps/scan/tasks.py +++ b/chirps/scan/tasks.py @@ -4,6 +4,9 @@ from celery import shared_task from django.utils import timezone +from embedding.utils import create_embedding +from target.providers.pinecone import PineconeTarget +from target.providers.redis import RedisTarget from target.models import BaseTarget from .models import Finding, Result, Scan @@ -33,7 +36,13 @@ def scan_task(scan_id): rules_run = 0 for rule in scan.plan.rules.all(): logger.info('Starting rule evaluation', extra={'id': rule.id}) - results = target.search(query=rule.query_string, max_results=100) + + if isinstance(target, (RedisTarget, PineconeTarget)): + embedding = create_embedding(rule.query_string, 'text-embedding-ada-002', 'OA', scan.user) + query = embedding.vectors + else: + query = rule.query_string + results = target.search(query, max_results=100) for text in results: diff --git a/chirps/target/forms.py b/chirps/target/forms.py index 4ee8f092..de907fae 100644 --- a/chirps/target/forms.py +++ b/chirps/target/forms.py @@ -14,7 +14,7 @@ class Meta: """Django Meta options for the RedisTargetForm.""" model = RedisTarget - fields = ['name', 'host', 'port', 'database_name', 'username', 'password'] + fields = ['name', 'host', 'port', 'database_name', 'username', 'password', 'index_name', 'text_field', 'embedding_field'] widgets = { 'name': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'Enter a name for the target'}), @@ -23,6 +23,9 @@ class Meta: 'database_name': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'Database name'}), 'username': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'guest'}), 'password': forms.PasswordInput(attrs={'class': 'form-control'}), + 'index_name': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'The name of the index in which documents are stored'}), + 'text_field': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'The document field in which text is stored'}), + 'embedding_field': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'The document field in which embeddings are stored'}), } diff --git a/chirps/target/providers/redis.py b/chirps/target/providers/redis.py index 5f8c9229..8d8ae6cd 100644 --- a/chirps/target/providers/redis.py +++ b/chirps/target/providers/redis.py @@ -1,8 +1,10 @@ """Logic for interfacing with a Redis target.""" +import numpy as np from logging import getLogger from django.db import models from redis import Redis +from redis.commands.search.query import Query from target.models import BaseTarget logger = getLogger(__name__) @@ -17,15 +19,40 @@ class RedisTarget(BaseTarget): username = models.CharField(max_length=256) password = models.CharField(max_length=2048, blank=True, null=True) + index_name = models.CharField(max_length=256) + text_field = models.CharField(max_length=256) + embedding_field = models.CharField(max_length=256) + # Name of the file in the ./target/static/ directory to use as a logo html_logo = 'target/redis-logo.png' html_name = 'Redis' html_description = 'Redis Vector Database' - def search(self, query: str, max_results: int) -> str: + def search(self, vectors: list, max_results: int) -> str: """Search the Redis target with the specified query.""" - logger.error('RedisTarget search not implemented') - raise NotImplementedError + client = Redis( + host=self.host, + port=self.port, + db=self.database_name, + password=self.password, + username=self.username, + ) + index = client.ft(self.index_name) + + score_field = 'vec_score' + vector_param = 'vec_param' + + vss_query = f'*=>[KNN {max_results} @{self.embedding_field} ${vector_param} AS {score_field}]' + return_fields = [self.embedding_field, self.text_field, score_field] + + query = Query(vss_query).sort_by(score_field).paging(0, max_results).return_fields(*return_fields).dialect(2) + embedding = np.array(vectors, dtype=np.float32).tostring() # type: ignore + params: dict[str, float] = {vector_param: embedding} + results = index.search(query, query_params=params) + + print(f'results: {results.docs}') + + return results.docs def test_connection(self) -> bool: """Ensure that the Redis target can be connected to."""