Skip to content

Commit

Permalink
Merge pull request #98 from weaviate/add-support-for-auth
Browse files Browse the repository at this point in the history
Add support for auth
  • Loading branch information
antas-marcin authored Nov 15, 2024
2 parents 1427253 + f240b95 commit 11198ed
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 17 deletions.
66 changes: 50 additions & 16 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,37 @@
import os
from typing import Optional, List
from logging import getLogger
from fastapi import FastAPI, Response, status
from fastapi import FastAPI, Depends, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from typing import Union
from config import TRUST_REMOTE_CODE
from config import TRUST_REMOTE_CODE, get_allowed_tokens
from vectorizer import Vectorizer, VectorInput
from meta import Meta


app = FastAPI()
logger = getLogger("uvicorn")

# Vectorizer used by the /vectors endpoints. Only declared here; the startup
# hook declares it `global` so it can bind the instance.
vec: Vectorizer
# Metadata provider whose `get()` result is served by /meta. Only declared
# here; the startup hook declares it `global` as well.
meta_config: Meta
# Log through the "uvicorn" logger.
logger = getLogger("uvicorn")

# FastAPI dependency that extracts bearer credentials from a request.
# NOTE(review): with auto_error=False a missing/invalid Authorization header
# presumably yields None instead of an automatic error response -- the
# handlers below rely on receiving None in that case; confirm against FastAPI.
get_bearer_token = HTTPBearer(auto_error=False)
# Tokens accepted by is_authorized(). None means authorization is disabled and
# every request is accepted; the startup hook fills this from
# get_allowed_tokens().
allowed_tokens: Optional[List[str]] = None


def is_authorized(auth: Optional[HTTPAuthorizationCredentials]) -> bool:
    """Tell whether a request may proceed given its bearer credentials.

    Args:
        auth: Credentials extracted from the request, or ``None`` when the
            request carried none.

    Returns:
        ``True`` when ``allowed_tokens`` is ``None`` (authorization is
        disabled) or when ``auth.credentials`` equals one of the allowed
        tokens; ``False`` otherwise.
    """
    # Local import keeps this block self-contained without touching the
    # module's import section.
    from secrets import compare_digest

    if allowed_tokens is None:
        return True
    if auth is None:
        return False

    # compare_digest only accepts ASCII str, so compare the encoded bytes.
    presented = auth.credentials.encode("utf-8")
    matched = False
    for token in allowed_tokens:
        # Constant-time comparison, and no early exit from the loop, so the
        # response time does not reveal how much of a token was guessed.
        matched |= compare_digest(presented, token.encode("utf-8"))
    return matched


@app.on_event("startup")
def startup_event():
async def lifespan(app: FastAPI):
global vec
global meta_config
global allowed_tokens

allowed_tokens = get_allowed_tokens()

cuda_env = os.getenv("ENABLE_CUDA")
cuda_per_process_memory_fraction = 1.0
Expand Down Expand Up @@ -113,6 +128,10 @@ def log_info_about_onnx(onnx_runtime: bool):
model_name,
trust_remote_code,
)
yield


app = FastAPI(lifespan=lifespan)


@app.get("/.well-known/live", response_class=Response)
Expand All @@ -122,17 +141,32 @@ async def live_and_ready(response: Response):


@app.get("/meta")
def meta(
    response: Response,
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
):
    """Serve the model metadata to authorized callers.

    Args:
        response: Outgoing response, mutated only on the unauthorized path.
        auth: Bearer credentials extracted from the request, if any.

    Returns:
        The result of ``meta_config.get()`` when the caller is authorized;
        otherwise ``{"error": "Unauthorized"}`` with HTTP status 401.
    """
    if not is_authorized(auth):
        response.status_code = status.HTTP_401_UNAUTHORIZED
        # A 401 response must carry a challenge telling the client which
        # authentication scheme to use.
        response.headers["WWW-Authenticate"] = "Bearer"
        return {"error": "Unauthorized"}

    return meta_config.get()


@app.post("/vectors")
@app.post("/vectors/")
async def vectorize(
    item: VectorInput,
    response: Response,
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
):
    """Vectorize the submitted text for authorized callers.

    Args:
        item: Request body; its ``text`` and ``config`` are passed to the
            vectorizer.
        response: Outgoing response, mutated only on the error paths.
        auth: Bearer credentials extracted from the request, if any.

    Returns:
        ``{"text", "vector", "dim"}`` on success; ``{"error": ...}`` with
        status 401 when the caller is not authorized, or with status 500 when
        vectorization raises.
    """
    if not is_authorized(auth):
        response.status_code = status.HTTP_401_UNAUTHORIZED
        # A 401 response must carry a challenge telling the client which
        # authentication scheme to use.
        response.headers["WWW-Authenticate"] = "Bearer"
        return {"error": "Unauthorized"}

    try:
        vector = await vec.vectorize(item.text, item.config)
        return {"text": item.text, "vector": vector.tolist(), "dim": len(vector)}
    except Exception as e:
        # Endpoint boundary: log the traceback and report the failure to the
        # client instead of letting the exception escape.
        logger.exception("Something went wrong while vectorizing data.")
        response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
        return {"error": str(e)}
12 changes: 11 additions & 1 deletion cicd/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,17 @@ local_repo=${LOCAL_REPO?Variable LOCAL_REPO is required}

pip3 install -r requirements-test.txt

docker run -d -it -p "8000:8080" "$local_repo"
echo "Running tests with authorization on"

container_id=$(docker run -d -it -e AUTHENTICATION_ALLOWED_TOKENS='token1,token2' -p "8000:8080" "$local_repo")

python3 smoke_auth_test.py

docker stop $container_id

echo "Running tests without authorization"

container_id=$(docker run -d -it -p "8000:8080" "$local_repo")

python3 smoke_test.py
pytest test_app.py
8 changes: 8 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
import os
from typing import List

# Raw value of the TRUST_REMOTE_CODE environment variable, or False when it is
# unset. os.getenv returns the string as-is, so this is a str or False.
# NOTE(review): any non-empty string (including "false" or "0") is truthy --
# confirm how consumers interpret this value.
TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", False)


def get_allowed_tokens() -> List[str] | None:
if (
tokens := os.getenv("AUTHENTICATION_ALLOWED_TOKENS", "").strip()
) and tokens != "":
return tokens.strip().split(",")
93 changes: 93 additions & 0 deletions smoke_auth_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import time
import unittest
import requests


class SmokeTest(unittest.TestCase):
    """Smoke tests against a running service that has token auth enabled.

    The service is expected at http://localhost:8000 and must accept the
    bearer tokens ``token1`` and ``token2``.
    """

    # How many one-second-spaced readiness probes to make before giving up.
    STARTUP_ATTEMPTS = 100
    # Per-request timeouts in seconds, so a hung service fails the run
    # instead of blocking it indefinitely.
    READY_TIMEOUT = 5
    REQUEST_TIMEOUT = 60

    def setUp(self):
        """Wait until the readiness endpoint answers 204, or fail the test."""
        self.url = "http://localhost:8000"

        for attempt in range(self.STARTUP_ATTEMPTS):
            try:
                res = requests.get(
                    self.url + "/.well-known/ready", timeout=self.READY_TIMEOUT
                )
            except requests.RequestException as e:
                # Connection errors are expected while the container boots.
                print(f"Attempt {attempt}: {e}")
            else:
                if res.status_code == 204:
                    return
                print(f"Attempt {attempt}: status code is {res.status_code}")
            time.sleep(1)

        raise RuntimeError("did not start up")

    def _assert_unauthorized(self, res):
        """Assert that *res* is the service's 401 "Unauthorized" response."""
        self.assertEqual(res.status_code, 401)
        self.assertEqual(res.json()["error"], "Unauthorized")

    def test_well_known_ready(self):
        res = requests.get(
            self.url + "/.well-known/ready", timeout=self.REQUEST_TIMEOUT
        )

        self.assertEqual(res.status_code, 204)

    def test_well_known_live(self):
        res = requests.get(
            self.url + "/.well-known/live", timeout=self.REQUEST_TIMEOUT
        )

        self.assertEqual(res.status_code, 204)

    def test_meta_unauthorized(self):
        # No credentials at all.
        res = requests.get(self.url + "/meta", timeout=self.REQUEST_TIMEOUT)
        self._assert_unauthorized(res)

        # Credentials present, but not an allowed token.
        headers = {"Authorization": "Bearer bad-token"}
        res = requests.get(
            self.url + "/meta", headers=headers, timeout=self.REQUEST_TIMEOUT
        )
        self._assert_unauthorized(res)

    def test_meta(self):
        headers = {"Authorization": "Bearer token1"}
        res = requests.get(
            self.url + "/meta", headers=headers, timeout=self.REQUEST_TIMEOUT
        )

        self.assertEqual(res.status_code, 200)
        self.assertIsInstance(res.json(), dict)

    def test_vectorizing_unauthorized(self):
        req_body = {"text": "The London Eye is a ferris wheel at the River Thames."}

        # No credentials at all.
        res = requests.post(
            self.url + "/vectors", json=req_body, timeout=self.REQUEST_TIMEOUT
        )
        self._assert_unauthorized(res)

        # Credentials present, but not an allowed token.
        headers = {"Authorization": "Bearer bad-token"}
        res = requests.post(
            self.url + "/vectors",
            json=req_body,
            headers=headers,
            timeout=self.REQUEST_TIMEOUT,
        )
        self._assert_unauthorized(res)

    def test_vectorizing(self):
        def get_req_body(task_type: str = ""):
            req_body = {"text": "The London Eye is a ferris wheel at the River Thames."}
            if task_type != "":
                req_body["config"] = {"task_type": task_type}
            return req_body

        def try_to_vectorize(url, task_type: str = ""):
            print(f"url: {url}")
            headers = {"Authorization": "Bearer token2"}
            res = requests.post(
                url,
                json=get_req_body(task_type),
                headers=headers,
                timeout=self.REQUEST_TIMEOUT,
            )

            # Check the status before parsing so a failure reports the
            # status code rather than a JSON/key error.
            self.assertEqual(200, res.status_code)
            res_body = res.json()

            # below tests that what we deem a reasonable vector is returned. We are
            # aware of 384 and 768 dim vectors, which should both fall in that
            # range
            self.assertGreater(len(res_body["vector"]), 100)
            print(f"vector dimensions are: {len(res_body['vector'])}")

        # Both route spellings must be served.
        try_to_vectorize(self.url + "/vectors/")
        try_to_vectorize(self.url + "/vectors")


if __name__ == "__main__":
unittest.main()

0 comments on commit 11198ed

Please sign in to comment.