-
Notifications
You must be signed in to change notification settings - Fork 49
/
routers.py
141 lines (112 loc) · 5.72 KB
/
routers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""File containing root routes"""
from fastapi.routing import APIRouter
from fastapi import Depends, HTTPException
from fastapi import File, UploadFile
from fastapi.encoders import jsonable_encoder
import base64
import httpx
import asyncio
import copy
import json
import tempfile
from pathlib import Path
# Schemas
from src.schemas import FunctionCall
from src.schemas import (
ChatRequest,
FunctionCall,
AudioTranscriptRequest,
AudioTTSRequest,
)
from src.handlers import MainHandler
from src.data.data_models import Restaurant, Foods
# Services
from src.services import openai_service, functions
# Data
from sqlalchemy.orm import Session
from src.data.data_utils import get_db
def create_router(handler: MainHandler, CONFIG):
    """Build and return the application's APIRouter.

    Registers chat (completion + function calling), audio (transcription +
    TTS), and database (restaurants/foods) routes.

    Args:
        handler: MainHandler exposing ``openai_client``, ``prompt_handler``
            and ``audio_handler`` used by the route implementations.
        CONFIG: configuration object threaded through to the service layer.

    Returns:
        APIRouter with all routes registered.
    """
    router = APIRouter()
    client = handler.openai_client

    ## Question answering
    @router.post("/chat/send_message")
    async def send_message(prompt_request: ChatRequest):
        """Receives the chatlog from the user and answers"""
        # Initializes the handler
        prompt_handler = handler.prompt_handler
        # Collects the messages in a list of dicts
        messages = prompt_handler.get_messages(prompt_request)
        # Function-calling schemas. NOTE: named `function_schemas` (not
        # `functions`) to avoid shadowing the imported
        # src.services.functions module.
        function_schemas = []
        if prompt_request.function_call:
            function_schemas = prompt_handler.get_functions()
        try:
            # Calls the main chat completion function
            prompt_response = await openai_service.chat_completion(
                messages=messages,
                CONFIG=CONFIG,
                functions=function_schemas,
                client=client,
            )
            # Formats and returns
            response = prompt_handler.prepare_response(prompt_response)
        except Exception as e:
            # Top-level boundary: never propagate internals to the client.
            print(e)
            response = {"response": "Oops there was an error, please try again", "function_call": None}
        return response

    @router.post("/chat/function_call")
    async def function_call(function_call: FunctionCall):
        """Receives the function call from the frontend and executes it"""
        # Preparing functions
        function_call_properties = jsonable_encoder(function_call)
        function_name = function_call_properties["name"]
        function_arguments = json.loads(function_call_properties["arguments"])
        # Configuring functions to be called - it should match the get_functions_signatures, otherwise we need to bypass it
        available_functions = {
            # Obs: all functions need to be async
            "get_restaurant_pages": lambda kwargs: functions.find_restaurant_pages(CONFIG=CONFIG, **kwargs),
            "open_restaurant_page": lambda kwargs: functions.open_restaurant_page(CONFIG=CONFIG, **kwargs),
            "close_restaurant_page": lambda _: functions.dummy_function(),  # dummy function - no need of information
            "get_user_actions": lambda _: functions.dummy_function(),  # dummy function - actions are stored in the frontend
            "get_menu_of_restaurant": lambda kwargs: functions.get_menu_of_restaurant(CONFIG=CONFIG, **kwargs),
            "add_food_to_cart": lambda kwargs: functions.add_food_to_cart(CONFIG=CONFIG, **kwargs),
            "remove_food_from_cart": lambda kwargs: functions.remove_food_from_cart(CONFIG=CONFIG, **kwargs),
            "open_shopping_cart": lambda _: functions.dummy_function(),  # dummy function - no need of information
            "close_shopping_cart": lambda _: functions.dummy_function(),  # dummy function - no need of information
            "place_order": lambda _: functions.dummy_function(),  # dummy function - no need of information
            "activate_handsfree": lambda _: functions.dummy_function(),  # dummy function - no need of information
        }
        # Reject unknown names explicitly instead of letting the dict lookup
        # raise KeyError (which would surface as an opaque 500).
        if function_name not in available_functions:
            raise HTTPException(status_code=404, detail=f"Unknown function: {function_name}")
        # Calling the function selected
        function_response = await available_functions[function_name](function_arguments)
        return {"response": function_response}

    @router.post("/chat/transcribe")
    async def generate_transcription(audio_req: AudioTranscriptRequest):
        """Receives the audio file from the frontend and transcribes it"""
        # Initializes the handler
        audio_handler = handler.audio_handler
        # Extracts the audio segment of the file
        audio_segment, _ = audio_handler.extract_audio_segment(audio_req.audio)
        # Send it as a tempfile path to openai - because that's the acceptable way to do it
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=True) as tmp_file:
            audio_segment.export(tmp_file.name, format="mp3")
            speech_filepath = Path(tmp_file.name)
            # Context manager guarantees the handle is closed (the original
            # `open(...)` leaked it).
            with open(speech_filepath, "rb") as audio_file:
                transcripted_response = await openai_service.whisper(
                    audio_file=audio_file, CONFIG=CONFIG, client=client
                )
        return {"response": transcripted_response}

    @router.post("/chat/tts")
    async def generate_tts(tts_req: AudioTTSRequest):
        """Receives the text from the frontend and generates the audio file"""
        # Generates the audio file
        audio = await openai_service.tts(text=tts_req.text, CONFIG=CONFIG, client=client)
        return {"response": audio}

    ## Retrieving from the database
    @router.get("/restaurants/")
    def get_restaurants(db: Session = Depends(get_db)):
        """Return every restaurant row in the database."""
        return db.query(Restaurant).all()

    @router.get("/restaurants/{restaurant_id}/foods/")
    def get_foods_from_restaurant(restaurant_id: int, db: Session = Depends(get_db)):
        """Return all foods for one restaurant; 404 if the restaurant is unknown."""
        restaurant = db.query(Restaurant).filter(Restaurant.id == restaurant_id).first()
        if not restaurant:
            raise HTTPException(status_code=404, detail="Restaurant not found")
        foods = db.query(Foods).filter(Foods.restaurant_id == restaurant_id).all()
        return foods

    return router