diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index e1232d2a..2ec324de 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,34 +1,15 @@ -// For format details, see https://aka.ms/devcontainer.json. For config options, see the -// README at: https://github.com/devcontainers/templates/tree/main/src/python { - "name": "Python 3", - // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile - "image": "mcr.microsoft.com/devcontainers/python:0-3.11", - "features": { - "ghcr.io/devcontainers-contrib/features/rabbitmq-asdf:1": { - "version": "latest", - "erlangVersion": "latest" - }, - "ghcr.io/devcontainers-contrib/features/redis-homebrew:1": { - "version": "latest" - }, - "ghcr.io/devcontainers-contrib/features/vault-asdf:2": { - "version": "latest" - } - } - - // Features to add to the dev container. More info: https://containers.dev/features. - // "features": {}, - - // Use 'forwardPorts' to make a list of ports inside the container available locally. - // "forwardPorts": [], - - // Use 'postCreateCommand' to run commands after the container is created. - // "postCreateCommand": "pip3 install --user -r requirements.txt", - - // Configure tool-specific properties. - // "customizations": {}, - - // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. - // "remoteUser": "root" + "name": "Python 3", + "dockerComposeFile": "docker-compose.yml", + "service": "app", + "workspaceFolder": "/workspace", + "shutdownAction": "stopCompose", + "features": { + "ghcr.io/devcontainers-contrib/features/rabbitmq-asdf:1": {}, + "ghcr.io/devcontainers-contrib/features/vault-asdf:2": {}, + "ghcr.io/devcontainers/features/docker-in-docker:2": {} + }, + "forwardPorts": [4369, 5672, 6379, 8000], + "postCreateCommand": "docker-compose up -d", + "runServices": ["redis", "celery", "rabbitmq", "vault"] } diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml new file mode 100644 index 00000000..e30096d2 --- /dev/null +++ b/.devcontainer/docker-compose.yml @@ -0,0 +1,64 @@ +version: '3.8' + +services: + app: + image: mcr.microsoft.com/devcontainers/python:0-3.11 + volumes: + - ..:/workspace:cached + command: sleep infinity + depends_on: + - redis + - celery + - rabbitmq + - vault + networks: + default: + aliases: + - localhost + + celery: + build: + context: . + dockerfile: Dockerfile.celery + volumes: + - .:/workspace:cached + depends_on: + - rabbitmq + networks: + default: + aliases: + - localhost + + redis: + image: redis/redis-stack:latest + ports: + - 6379:6379 + networks: + default: + aliases: + - localhost + + rabbitmq: + image: rabbitmq:latest + ports: + - 5672:5672 + - 15672:15672 + networks: + default: + aliases: + - localhost + + vault: + image: hashicorp/vault:latest + ports: + - 8200:8200 + environment: + VAULT_DEV_ROOT_TOKEN_ID: myroot + VAULT_DEV_LISTEN_ADDRESS: 0.0.0.0:8200 + networks: + default: + aliases: + - localhost + +networks: + default: diff --git a/chirps/base_app/management/commands/redis.py b/chirps/base_app/management/commands/redis.py index 1444e306..5f5c3bd6 100644 --- a/chirps/base_app/management/commands/redis.py +++ b/chirps/base_app/management/commands/redis.py @@ -18,8 +18,8 @@ def add_arguments(self, parser): def handle(self, *args, **options): """Handle redis command""" if options['start']: - os.system('redis-server --daemonize yes') + os.system('docker-compose -f /workspace/.devcontainer/docker-compose.yml up -d redis') elif options['stop']: - os.system('redis-cli shutdown') + os.system('docker-compose -f /workspace/.devcontainer/docker-compose.yml down') elif options['status']: - os.system('redis-cli ping') + os.system('docker-compose -f /workspace/.devcontainer/docker-compose.yml ps') diff --git a/chirps/scan/tasks.py b/chirps/scan/tasks.py index fbcffd7a..4e1443b5 100644 --- a/chirps/scan/tasks.py +++ b/chirps/scan/tasks.py @@ -1,9 +1,13 @@ """Celery tasks for the scan application.""" +import json import re from logging import getLogger from celery import shared_task from django.utils import timezone +from embedding.utils import create_embedding +from target.providers.pinecone import PineconeTarget +from target.providers.redis import RedisTarget from target.models import BaseTarget from .models import Finding, Result, Scan @@ -27,13 +31,20 @@ def scan_task(scan_id): # Need to perform a secondary query in order to fetch the derrived class # This magic is handled by django-polymorphic target = BaseTarget.objects.get(id=scan.target.id) + embed_query = isinstance(target, (RedisTarget, PineconeTarget)) # Now that we have the derrived class, call its implementation of search() total_rules = scan.plan.rules.all().count() rules_run = 0 for rule in scan.plan.rules.all(): + if embed_query: + embedding = create_embedding(rule.query_string, 'text-embedding-ada-002', 'OA', scan.user) + print(f'vectors: {embedding.vectors}') + query = embedding.vectors + else: + query = rule.query_string logger.info('Starting rule evaluation', extra={'id': rule.id}) - results = target.search(query=rule.query_string, max_results=100) + results = target.search(query, max_results=100) for text in results: diff --git a/chirps/target/forms.py b/chirps/target/forms.py index 4ee8f092..de907fae 100644 --- a/chirps/target/forms.py +++ b/chirps/target/forms.py @@ -14,7 +14,7 @@ class Meta: """Django Meta options for the RedisTargetForm.""" model = RedisTarget - fields = ['name', 'host', 'port', 'database_name', 'username', 'password'] + fields = ['name', 'host', 'port', 'database_name', 'username', 'password', 'index_name', 'text_field', 'embedding_field'] widgets = { 'name': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'Enter a name for the target'}), @@ -23,6 +23,9 @@ class Meta: 'database_name': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'Database name'}), 'username': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'guest'}), 'password': forms.PasswordInput(attrs={'class': 'form-control'}), + 'index_name': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'The name of the index in which documents are stored'}), + 'text_field': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'The document field in which text is stored'}), + 'embedding_field': forms.TextInput(attrs={'class': 'form-control', 'placeholder': 'The document field in which embeddings are stored'}), } diff --git a/chirps/target/migrations/0002_redistarget_embedding_field_redistarget_index_name_and_more.py b/chirps/target/migrations/0002_redistarget_embedding_field_redistarget_index_name_and_more.py new file mode 100644 index 00000000..feadff62 --- /dev/null +++ b/chirps/target/migrations/0002_redistarget_embedding_field_redistarget_index_name_and_more.py @@ -0,0 +1,30 @@ +# Generated by Django 4.2.2 on 2023-07-12 13:59 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("target", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="redistarget", + name="embedding_field", + field=models.CharField(default="embedding", max_length=256), + preserve_default=False, + ), + migrations.AddField( + model_name="redistarget", + name="index_name", + field=models.CharField(default="index", max_length=256), + preserve_default=False, + ), + migrations.AddField( + model_name="redistarget", + name="text_field", + field=models.CharField(default="content", max_length=256), + preserve_default=False, + ), + ] diff --git a/chirps/target/providers/redis.py b/chirps/target/providers/redis.py index 5f8c9229..99c08b0b 100644 --- a/chirps/target/providers/redis.py +++ b/chirps/target/providers/redis.py @@ -1,8 +1,10 @@ """Logic for interfacing with a Redis target.""" +import numpy as np from logging import getLogger from django.db import models from redis import Redis +from redis.commands.search.query import Query from target.models import BaseTarget logger = getLogger(__name__) @@ -17,15 +19,40 @@ class RedisTarget(BaseTarget): username = models.CharField(max_length=256) password = models.CharField(max_length=2048, blank=True, null=True) + index_name = models.CharField(max_length=256) + text_field = models.CharField(max_length=256) + embedding_field = models.CharField(max_length=256) + # Name of the file in the ./target/static/ directory to use as a logo html_logo = 'target/redis-logo.png' html_name = 'Redis' html_description = 'Redis Vector Database' - def search(self, query: str, max_results: int) -> str: + def search(self, vectors: list, max_results: int) -> str: """Search the Redis target with the specified query.""" - logger.error('RedisTarget search not implemented') - raise NotImplementedError + client = Redis( + host=self.host, + port=self.port, + db=self.database_name, + password=self.password, + username=self.username, + ) + index = client.ft(self.index_name) + + score_field = 'vec_score' + vector_param = 'vec_param' + + vss_query = f'*=>[KNN {max_results} @{self.embedding_field} ${vector_param} AS {score_field}]' + return_fields = [self.embedding_field, self.text_field, score_field] + + query = Query(vss_query).sort_by(score_field).paging(0, max_results).return_fields(*return_fields).dialect(2) + embedding = np.array(vectors, dtype=np.float32).tostring() # type: ignore + params: dict[str, float] = {vector_param: embedding} + results = index.search(query, query_params=params) + + print(f'results: {results.docs}') + + return results.docs def test_connection(self) -> bool: """Ensure that the Redis target can be connected to.""" @@ -36,4 +63,5 @@ def test_connection(self) -> bool: password=self.password, username=self.username, ) + return client.ping()