Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RedisTarget search #62

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 13 additions & 32 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
@@ -1,34 +1,15 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/python
{
"name": "Python 3",
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
"image": "mcr.microsoft.com/devcontainers/python:0-3.11",
"features": {
"ghcr.io/devcontainers-contrib/features/rabbitmq-asdf:1": {
"version": "latest",
"erlangVersion": "latest"
},
"ghcr.io/devcontainers-contrib/features/redis-homebrew:1": {
"version": "latest"
},
"ghcr.io/devcontainers-contrib/features/vault-asdf:2": {
"version": "latest"
}
}

// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {},

// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],

// Use 'postCreateCommand' to run commands after the container is created.
// "postCreateCommand": "pip3 install --user -r requirements.txt",

// Configure tool-specific properties.
// "customizations": {},

// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
"name": "Python 3",
"dockerComposeFile": "docker-compose.yml",
"service": "app",
"workspaceFolder": "/workspace",
"shutdownAction": "stopCompose",
"features": {
"ghcr.io/devcontainers-contrib/features/rabbitmq-asdf:1": {},
"ghcr.io/devcontainers-contrib/features/vault-asdf:2": {},
"ghcr.io/devcontainers/features/docker-in-docker:2": {}
},
"forwardPorts": [4369, 5672, 6379, 8000],
"postCreateCommand": "docker-compose up -d",
"runServices": ["redis", "celery", "rabbitmq", "vault"]
}
64 changes: 64 additions & 0 deletions .devcontainer/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
version: '3.8'

services:
app:
image: mcr.microsoft.com/devcontainers/python:0-3.11
volumes:
- ..:/workspace:cached
command: sleep infinity
depends_on:
- redis
- celery
- rabbitmq
- vault
networks:
default:
aliases:
- localhost

celery:
build:
context: .
dockerfile: Dockerfile.celery
volumes:
- .:/workspace:cached
depends_on:
- rabbitmq
networks:
default:
aliases:
- localhost

redis:
image: redis/redis-stack:latest
ports:
- 6379:6379
networks:
default:
aliases:
- localhost

rabbitmq:
image: rabbitmq:latest
ports:
- 5672:5672
- 15672:15672
networks:
default:
aliases:
- localhost

vault:
image: hashicorp/vault:latest
ports:
- 8200:8200
environment:
VAULT_DEV_ROOT_TOKEN_ID: myroot
VAULT_DEV_LISTEN_ADDRESS: 0.0.0.0:8200
networks:
default:
aliases:
- localhost

networks:
default:
6 changes: 3 additions & 3 deletions chirps/base_app/management/commands/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def add_arguments(self, parser):
def handle(self, *args, **options):
    """Start, stop, or report the status of the devcontainer Redis service.

    The ``start``, ``stop`` and ``status`` options are checked in that order
    and only the first truthy one is acted on. Each action shells out to
    ``docker-compose`` against the devcontainer compose file.
    """
    compose = 'docker-compose -f /workspace/.devcontainer/docker-compose.yml'
    if options['start']:
        os.system(f'{compose} up -d redis')
    elif options['stop']:
        # Stop only the redis service; a bare `down` would tear down every
        # service declared in the compose file, not just redis.
        os.system(f'{compose} stop redis')
    elif options['status']:
        # Limit the listing to the service this command manages.
        os.system(f'{compose} ps redis')
13 changes: 12 additions & 1 deletion chirps/scan/tasks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""Celery tasks for the scan application."""
import json
import re
from logging import getLogger

from celery import shared_task
from django.utils import timezone
from embedding.utils import create_embedding
from target.providers.pinecone import PineconeTarget
from target.providers.redis import RedisTarget
from target.models import BaseTarget

from .models import Finding, Result, Scan
Expand All @@ -27,13 +31,20 @@ def scan_task(scan_id):
# Need to perform a secondary query in order to fetch the derived class
# This magic is handled by django-polymorphic
target = BaseTarget.objects.get(id=scan.target.id)
embed_query = isinstance(target, (RedisTarget, PineconeTarget))

# Now that we have the derived class, call its implementation of search()
total_rules = scan.plan.rules.all().count()
rules_run = 0
for rule in scan.plan.rules.all():
if embed_query:
embedding = create_embedding(rule.query_string, 'text-embedding-ada-002', 'OA', scan.user)
print(f'vectors: {embedding.vectors}')
query = embedding.vectors
else:
query = rule.query_string
logger.info('Starting rule evaluation', extra={'id': rule.id})
results = target.search(query=rule.query_string, max_results=100)
results = target.search(query, max_results=100)

for text in results:

Expand Down
5 changes: 4 additions & 1 deletion chirps/target/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Meta:
"""Django Meta options for the RedisTargetForm."""

model = RedisTarget
fields = ['name', 'host', 'port', 'database_name', 'username', 'password']
fields = ['name', 'host', 'port', 'database_name', 'username', 'password', 'index_name', 'text_field', 'embedding_field']

widgets = {
'name': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'Enter a name for the target'}),
Expand All @@ -23,6 +23,9 @@ class Meta:
'database_name': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'Database name'}),
'username': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'guest'}),
'password': forms.PasswordInput(attrs={'class': 'form-control'}),
'index_name': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'The name of the index in which documents are stored'}),
'text_field': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'The document field in which text is stored'}),
'embedding_field': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'The document field in which embeddings are stored'}),
}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Generated by Django 4.2.2 on 2023-07-12 13:59

from django.db import migrations, models


class Migration(migrations.Migration):
    """Add the index, text and embedding field-name columns to ``redistarget``."""

    dependencies = [("target", "0001_initial")]

    # One AddField per new column, in the same order as originally generated.
    # Every column is a 256-character CharField added with its own default
    # and preserve_default=False.
    operations = [
        migrations.AddField(
            model_name="redistarget",
            name=column,
            field=models.CharField(default=initial, max_length=256),
            preserve_default=False,
        )
        for column, initial in (
            ("embedding_field", "embedding"),
            ("index_name", "index"),
            ("text_field", "content"),
        )
    ]
34 changes: 31 additions & 3 deletions chirps/target/providers/redis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Logic for interfacing with a Redis target."""
import numpy as np
from logging import getLogger

from django.db import models
from redis import Redis
from redis.commands.search.query import Query
from target.models import BaseTarget

logger = getLogger(__name__)
Expand All @@ -17,15 +19,40 @@ class RedisTarget(BaseTarget):
username = models.CharField(max_length=256)
password = models.CharField(max_length=2048, blank=True, null=True)

index_name = models.CharField(max_length=256)
text_field = models.CharField(max_length=256)
embedding_field = models.CharField(max_length=256)

# Name of the file in the ./target/static/ directory to use as a logo
html_logo = 'target/redis-logo.png'
html_name = 'Redis'
html_description = 'Redis Vector Database'

def search(self, vectors: list, max_results: int) -> list:
    """Run a KNN vector-similarity search against this target's index.

    Args:
        vectors: The query embedding; it is converted to a float32 array
            and sent to Redis as raw bytes.
        max_results: Number of nearest neighbours to request, also used as
            the page size of the query.

    Returns:
        The ``docs`` of the search result. Each query asks Redis to return
        the embedding field, the text field and the ``vec_score`` field.
    """
    client = Redis(
        host=self.host,
        port=self.port,
        db=self.database_name,
        password=self.password,
        username=self.username,
    )
    try:
        index = client.ft(self.index_name)

        score_field = 'vec_score'
        vector_param = 'vec_param'

        vss_query = f'*=>[KNN {max_results} @{self.embedding_field} ${vector_param} AS {score_field}]'
        return_fields = [self.embedding_field, self.text_field, score_field]

        query = (
            Query(vss_query)
            .sort_by(score_field)
            .paging(0, max_results)
            .return_fields(*return_fields)
            .dialect(2)
        )

        # tobytes() replaces the deprecated ndarray.tostring() alias.
        # NOTE(review): assumes the index stores its vectors as float32 - confirm.
        embedding = np.array(vectors, dtype=np.float32).tobytes()
        params: dict[str, bytes] = {vector_param: embedding}
        results = index.search(query, query_params=params)
    finally:
        # Release the connection even when the query raises.
        client.close()

    logger.debug('results: %s', results.docs)

    return results.docs

def test_connection(self) -> bool:
"""Ensure that the Redis target can be connected to."""
Expand All @@ -36,4 +63,5 @@ def test_connection(self) -> bool:
password=self.password,
username=self.username,
)

return client.ping()
Loading