diff --git a/adapter/verifiable_ai/mistral.py b/adapter/verifiable_ai/mistral.py index dab4a82..2472749 100644 --- a/adapter/verifiable_ai/mistral.py +++ b/adapter/verifiable_ai/mistral.py @@ -43,6 +43,7 @@ def __init__(self): self.api_key = os.getenv("API_KEY", None) def parse_input(self, request: Request) -> Input: + print(request) return Input(**request) def verify_output(self, input_: Input, output: Output): @@ -51,7 +52,7 @@ def verify_output(self, input_: Input, output: Output): def parse_output(self, output: Output) -> Response: return Response(**output) - async def call(self, input_: Input) -> Output: + async def call(self, input_: Input) -> Output: client = httpx.AsyncClient() response = await client.request( "POST", @@ -59,19 +60,7 @@ async def call(self, input_: Input) -> Output: headers={ "Authorization": "Bearer {}".format(self.api_key), }, - json={ - "model": input_["model"], - "messages": [{ - "role": "user", - "content": input_["messages"], - }], - "temperature": float(input_["temperature"]), - "top_p": float(input_["top_p"]), - "max_tokens": int(input_["max_tokens"]), - "stream": True if input_["stream"].lower() == "true" else False, - "safe_prompt": True if input_["safe_prompt"].lower() == "true" else False, - "random_seed": int(input_["random_seed"]), - } + json=input_ ) response.raise_for_status() diff --git a/adapter/verifiable_ai/openai.py b/adapter/verifiable_ai/openai.py index b2268ad..d522761 100644 --- a/adapter/verifiable_ai/openai.py +++ b/adapter/verifiable_ai/openai.py @@ -57,18 +57,7 @@ async def call(self, input_: Input) -> Output: headers={ "Authorization": "Bearer {}".format(self.api_key), }, - json={ - "model": input_["model"], - "messages": [{ - "role": "user", - "content": input_["messages"], - }], - "temperature": float(input_["temperature"]), - "top_p": float(input_["top_p"]), - "max_tokens": int(input_["max_tokens"]), - "stream": True if input_["stream"].lower() == "true" else False, - "seed": int(input_["seed"]), - } + json=input_ ) response.raise_for_status() diff --git a/app/main.py b/app/main.py index 240b229..06fafe0 100644 --- a/app/main.py +++ b/app/main.py @@ -94,6 +94,7 @@ @request_app.get("/") +@request_app.post("/") async def request_data(request: Request) -> Any: """Requests data from the premium data source""" report = ProviderResponseReport( @@ -101,7 +102,10 @@ async def request_data(request: Request) -> Any: created_at=datetime.utcnow(), ) try: - return await adapter.unified_call(dict(request.query_params)) + if request.method == "POST": + return await adapter.unified_call(await request.json()) + else: + return await adapter.unified_call(dict(request.query_params)) except HTTPStatusError as e: report.response_code = e.response.status_code report.error_msg = str(e)