Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
satawatnack committed Mar 7, 2024
1 parent b5daf62 commit db6fbbd
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 27 deletions.
17 changes: 3 additions & 14 deletions adapter/verifiable_ai/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(self):
self.api_key = os.getenv("API_KEY", None)

def parse_input(self, request: Request) -> Input:
print(request)
return Input(**request)

def verify_output(self, input_: Input, output: Output):
Expand All @@ -51,27 +52,15 @@ def verify_output(self, input_: Input, output: Output):
def parse_output(self, output: Output) -> Response:
return Response(**output)

async def call(self, input_: Input) -> Output:
async def call(self, input_: Input) -> Output:
client = httpx.AsyncClient()
response = await client.request(
"POST",
self.api_url,
headers={
"Authorization": "Bearer {}".format(self.api_key),
},
json={
"model": input_["model"],
"messages": [{
"role": "user",
"content": input_["messages"],
}],
"temperature": float(input_["temperature"]),
"top_p": float(input_["top_p"]),
"max_tokens": int(input_["max_tokens"]),
"stream": True if input_["stream"].lower() == "true" else False,
"safe_prompt": True if input_["safe_prompt"].lower() == "true" else False,
"random_seed": int(input_["random_seed"]),
}
json=input_
)

response.raise_for_status()
Expand Down
13 changes: 1 addition & 12 deletions adapter/verifiable_ai/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,7 @@ async def call(self, input_: Input) -> Output:
headers={
"Authorization": "Bearer {}".format(self.api_key),
},
json={
"model": input_["model"],
"messages": [{
"role": "user",
"content": input_["messages"],
}],
"temperature": float(input_["temperature"]),
"top_p": float(input_["top_p"]),
"max_tokens": int(input_["max_tokens"]),
"stream": True if input_["stream"].lower() == "true" else False,
"seed": int(input_["seed"]),
}
json=input_
)

response.raise_for_status()
Expand Down
6 changes: 5 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,18 @@


@request_app.get("/")
@request_app.post("/")
async def request_data(request: Request) -> Any:
"""Requests data from the premium data source"""
report = ProviderResponseReport(
response_code=200,
created_at=datetime.utcnow(),
)
try:
return await adapter.unified_call(dict(request.query_params))
if request.method == "POST":
return await adapter.unified_call(await request.json())
else:
return await adapter.unified_call(dict(request.query_params))
except HTTPStatusError as e:
report.response_code = e.response.status_code
report.error_msg = str(e)
Expand Down

0 comments on commit db6fbbd

Please sign in to comment.