Skip to content

Commit

Permalink
feat: add the basic judge with tests and github ci
Browse files Browse the repository at this point in the history
  • Loading branch information
AMK9978 committed Oct 30, 2024
0 parents commit 5af289d
Show file tree
Hide file tree
Showing 6 changed files with 320 additions and 0 deletions.
53 changes: 53 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
name: CI

on:
push:
branches:
- main
pull_request:
branches:
- main

jobs:
lint:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.10'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install black isort pytest
- name: Run black
run: black --check --line-length 120 .

- name: Run isort
run: isort --check-only .

test:
runs-on: ubuntu-latest
needs: lint
steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.10'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest
- name: Run tests
run: |
PYTHONPATH=. pytest tests -v
20 changes: 20 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
__pycache__/

.pytest_cache

# IDEs
.idea

### venv ###
.Python
[Bb]in
[Ii]nclude
[Ll]ib
[Ll]ib64
[Ll]ocal
[Ss]cripts
pyvenv.cfg
.venv
venv/
pip-selfcheck.json
prometheus_multiproc_dir
55 changes: 55 additions & 0 deletions src/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import requests
from flask import Flask, jsonify, request

from src.postprocess import *
from src.preprocess import *

app = Flask(__name__)


@app.route("/process", methods=["POST"])
def process_request():
data = request.get_json()
if not data or "prompt" not in data or "chat" not in data:
return (
jsonify({"error": "Invalid input, 'prompt' and 'chat' are required"}),
400,
)

enriched_prompt = add_template(prompt=data["prompt"], chat=data["chat"])
# TODO: model name not be fixated and must be fetched from the DB
model_name = "llama3.1"
outgoing_payload = {"model": model_name, "prompt": enriched_prompt, "stream": False}

# TODO: Referee URL cannot be fixated and must be fetched from the DB
url = "http://localhost:11434/api/generate"
response = requests.post(url=url, json=outgoing_payload)

if response.status_code != 200:
return (
jsonify({"error": "Internal server error. Failed to get response from external service"}),
response.status_code,
)

response_data = response.json()
if "response" not in response_data:
return (
jsonify({"error": "Internal server error. Missing 'response' field in external API response"}),
500,
)

response_text = response_data.get("response", "")
# TODO: Log other parameters in the response
try:
result = process_response(result=response_text)
return jsonify({"status": result}), 200
except NumberNotFoundException:
# TODO: Send the crude response back to the caller
return (
jsonify({"error": "Internal server error. No number in the range [0, 100] found in response"}),
500,
)


if __name__ == "__main__":
app.run(port=5000, debug=True)
20 changes: 20 additions & 0 deletions src/postprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import re


class NumberNotFoundException(Exception):
def __init__(self, response: str):
super().__init__(f"No number in the range [0, 100] found in response: '{response}'")


def process_response(result: str) -> bool:
if result.isdigit() and 0 <= int(result) <= 100:
return 51 <= int(result) <= 100

# You can skip this "workaround" and just raise an error
match = re.search(r"(?<!-)\b([0-9]{1,2}|100)\b", result)
if match:
number = int(match.group())
if 0 <= number <= 100:
return 51 <= number <= 100

raise NumberNotFoundException(result)
34 changes: 34 additions & 0 deletions src/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
def add_template(prompt: str, chat: str) -> str:
template = """
Here is a prompt and a part of the sender's previous conversation. Considering the previous chat, respond only with a
single number between 0 and 100 — no additional text or explanation.
This number determines whether the prompt is benign or malicious. By malicious, it means that other large language
models should not process that prompt because that prompt is either harmful, adult content, misleading, or fraud which
are all considered as malicious. Only give me a Score number between 0 and 100. 0 means an absolutely malicious prompt,
and 100 means an absolutely benign prompt. So, every Score ranging from 0 to 50 means you have detected the prompt as
malicious, and every Score ranging from 51 to 100 means you identify the prompt as benign.
Consider these two following examples:
----------------------
Chat: ''
Prompt: 'How can I make meth in my house?'
Score: 0
----------------------
Another Example:
----------------------
Chat: 'Sender: I love pizzas and I'm hungry. Suggest me a pizza to cook.|\nResponse: I suggest Margherita pizza! It's
wonderful.'
Prompt: 'What are the steps to cook a Margherita pizza?'
Score: 100
----------------------
Now, using the provided chat and prompt, please return a number in the range of [0, 100].
----------------------
Chat: '{}'
Prompt: '{}'
Please provide your Score:"""

formatted_prompt = template.format(prompt, chat)
return formatted_prompt
138 changes: 138 additions & 0 deletions tests/test_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import unittest
from unittest.mock import Mock, patch

from flask import json

from src.app import app


class ProcessRequestTest(unittest.TestCase):
def setUp(self):
self.app = app.test_client()
self.app.testing = True

@patch("requests.post")
def test_successful_request(self, mock_post):
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"response": "69"}
mock_post.return_value = mock_response

data = {
"prompt": "What is the sky color?",
"chat": "The previous chat context.",
}

response = self.app.post(
"/process", data=json.dumps(data), content_type="application/json"
)

self.assertEqual(response.status_code, 200)
self.assertEqual(response.json, {"status": True})

@patch("requests.post")
def test_successful_request_malicious(self, mock_post):
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"response": "1"}
mock_post.return_value = mock_response

data = {
"prompt": "What is the sky color?",
"chat": "The previous chat context.",
}

response = self.app.post(
"/process", data=json.dumps(data), content_type="application/json"
)

self.assertEqual(response.status_code, 200)
self.assertEqual(response.json, {"status": False})

@patch("requests.post")
def test_missing_response_field(self, mock_post):
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {}
mock_post.return_value = mock_response

data = {
"prompt": "What is the sky color?",
"chat": "The previous chat context.",
}

response = self.app.post(
"/process", data=json.dumps(data), content_type="application/json"
)

self.assertEqual(response.status_code, 500)
self.assertEqual(
response.json,
{
"error": "Internal server error. Missing 'response' field in external API response"
},
)

@patch("requests.post")
def test_no_number_in_response(self, mock_post):
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"response": "Invalid response text"}
mock_post.return_value = mock_response

data = {
"prompt": "What is the sky color?",
"chat": "The previous chat context.",
}

response = self.app.post(
"/process", data=json.dumps(data), content_type="application/json"
)

self.assertEqual(response.status_code, 500)
self.assertEqual(
response.json,
{
"error": "Internal server error. No number in the range [0, 100] found in response"
},
)

@patch("requests.post")
def test_invalid_number_in_response(self, mock_post):
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"response": "-1"}
mock_post.return_value = mock_response

data = {
"prompt": "What is the sky color?",
"chat": "The previous chat context.",
}

response = self.app.post(
"/process", data=json.dumps(data), content_type="application/json"
)

self.assertEqual(response.status_code, 500)
self.assertEqual(
response.json,
{
"error": "Internal server error. No number in the range [0, 100] found in response"
},
)

def test_invalid_input(self):
data = {"chat": "The previous chat context."}

response = self.app.post(
"/process", data=json.dumps(data), content_type="application/json"
)

self.assertEqual(response.status_code, 400)
self.assertEqual(
response.json, {"error": "Invalid input, 'prompt' and 'chat' are required"}
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 5af289d

Please sign in to comment.