Skip to content

Commit

Permalink
implement search for RedisTarget
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-nork committed Jul 13, 2023
1 parent 114313c commit bbd9cd6
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 5 deletions.
11 changes: 10 additions & 1 deletion chirps/scan/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

from celery import shared_task
from django.utils import timezone
from embedding.utils import create_embedding
from target.providers.pinecone import PineconeTarget
from target.providers.redis import RedisTarget
from target.models import BaseTarget

from .models import Finding, Result, Scan
Expand Down Expand Up @@ -33,7 +36,13 @@ def scan_task(scan_id):
rules_run = 0
for rule in scan.plan.rules.all():
logger.info('Starting rule evaluation', extra={'id': rule.id})
results = target.search(query=rule.query_string, max_results=100)

if isinstance(target, (RedisTarget, PineconeTarget)):
embedding = create_embedding(rule.query_string, 'text-embedding-ada-002', 'OA', scan.user)
query = embedding.vectors
else:
query = rule.query_string
results = target.search(query, max_results=100)

for text in results:

Expand Down
5 changes: 4 additions & 1 deletion chirps/target/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Meta:
"""Django Meta options for the RedisTargetForm."""

model = RedisTarget
fields = ['name', 'host', 'port', 'database_name', 'username', 'password']
fields = ['name', 'host', 'port', 'database_name', 'username', 'password', 'index_name', 'text_field', 'embedding_field']

widgets = {
'name': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'Enter a name for the target'}),
Expand All @@ -23,6 +23,9 @@ class Meta:
'database_name': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'Database name'}),
'username': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'guest'}),
'password': forms.PasswordInput(attrs={'class': 'form-control'}),
'index_name': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'The name of the index in which documents are stored'}),
'text_field': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'The document field in which text is stored'}),
'embedding_field': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'The document field in which embeddings are stored'}),
}


Expand Down
33 changes: 30 additions & 3 deletions chirps/target/providers/redis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Logic for interfacing with a Redis target."""
import numpy as np
from logging import getLogger

from django.db import models
from redis import Redis
from redis.commands.search.query import Query
from target.models import BaseTarget

logger = getLogger(__name__)
Expand All @@ -17,15 +19,40 @@ class RedisTarget(BaseTarget):
username = models.CharField(max_length=256)
password = models.CharField(max_length=2048, blank=True, null=True)

index_name = models.CharField(max_length=256)
text_field = models.CharField(max_length=256)
embedding_field = models.CharField(max_length=256)

# Name of the file in the ./target/static/ directory to use as a logo
html_logo = 'target/redis-logo.png'
html_name = 'Redis'
html_description = 'Redis Vector Database'

def search(self, query: str, max_results: int) -> str:
def search(self, vectors: list, max_results: int) -> str:
"""Search the Redis target with the specified query."""
logger.error('RedisTarget search not implemented')
raise NotImplementedError
client = Redis(
host=self.host,
port=self.port,
db=self.database_name,
password=self.password,
username=self.username,
)
index = client.ft(self.index_name)

score_field = 'vec_score'
vector_param = 'vec_param'

vss_query = f'*=>[KNN {max_results} @{self.embedding_field} ${vector_param} AS {score_field}]'
return_fields = [self.embedding_field, self.text_field, score_field]

query = Query(vss_query).sort_by(score_field).paging(0, max_results).return_fields(*return_fields).dialect(2)
embedding = np.array(vectors, dtype=np.float32).tostring() # type: ignore
params: dict[str, float] = {vector_param: embedding}
results = index.search(query, query_params=params)

print(f'results: {results.docs}')

return results.docs

def test_connection(self) -> bool:
"""Ensure that the Redis target can be connected to."""
Expand Down

0 comments on commit bbd9cd6

Please sign in to comment.