
Commit 4ed84ec

Update URL path setting
yuting1214 authored Jun 25, 2024
1 parent 977a098 commit 4ed84ec
Showing 3 changed files with 24 additions and 25 deletions.
4 changes: 2 additions & 2 deletions backend/app/main.py
@@ -8,7 +8,7 @@
 from starlette.middleware.sessions import SessionMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
-from backend.app.core.init_settings import args, global_settings
+from backend.app.core.init_settings import args
 from backend.app.core.constants import API_BASE_URL
 from backend.app.api.v1.endpoints import (
     user,
@@ -54,7 +54,7 @@ async def lifespan(app: FastAPI):
 # Set Middleware
 # Define the allowed origins
 origins = [
-    global_settings.API_BASE_URL,
+    API_BASE_URL,
     "http://localhost",
     "http://localhost:5000",
 ]
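For context, the constant now imported directly from `backend.app.core.constants` is what feeds the CORS allow-list. A minimal sketch of that wiring, where the `add_middleware` options are illustrative assumptions, not part of this commit; only the `origins` list mirrors the diff:

```python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

API_BASE_URL = "http://localhost:5000"  # stand-in for backend.app.core.constants

app = FastAPI()

origins = [
    API_BASE_URL,
    "http://localhost",
    "http://localhost:5000",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,   # Origin headers allowed to call the API
    allow_credentials=True,  # assumption: session cookies are needed (SessionMiddleware is used)
    allow_methods=["*"],
    allow_headers=["*"],
)
```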
15 changes: 8 additions & 7 deletions frontend/gradio/text/event_listeners.py
@@ -4,6 +4,7 @@
 import time
 import gradio as gr
 import json
+from urllib.parse import urljoin
 from typing import List, Optional, Tuple
 from backend.app.core.constants import API_BASE_URL, MEMORY_WINDOW_SIZE
 from frontend.gradio.utils import add_message_to_db, start_chat
@@ -60,7 +61,7 @@ def enable_vote_buttons_arena():
 def add_rating_to_db(state: gr.State, rating_type: str):
     chat_id = state.value['current_chat_id']
     if chat_id:
-        rating_api_endpoint = API_BASE_URL + "api/v1/ratings/"
+        rating_api_endpoint = urljoin(API_BASE_URL, "api/v1/ratings/")
         rating_data = {
             "chat_id": chat_id,
             "rating_type": rating_type
@@ -106,7 +107,7 @@ def llm_text_completion_stream(state: gr.State, model_name: str, history: List[T
     model_endpoint = state.value['model_map'][model_name]

     # Call the LLM model with streaming
-    llm_api_endpoint = API_BASE_URL + "api/v1/llm/generation/stream/text"
+    llm_api_endpoint = urljoin(API_BASE_URL, "api/v1/llm/generation/stream/text")
     llm_data = {
         "user_id": user_id,
         "api_endpoint": model_endpoint,
@@ -139,7 +140,7 @@ def llm_text_completion_stream(state: gr.State, model_name: str, history: List[T
         time.sleep(0.01) # Control the streaming speed for the interface

     # Add message data (user) into DB
-    message_api_endpoint = API_BASE_URL + "api/v1/messages/"
+    message_api_endpoint = urljoin(API_BASE_URL, "api/v1/messages/")
     message_data_user = {
         "user_id": user_id,
         "content": user_message,
@@ -182,7 +183,7 @@ async def llm_text_completion_stream_async(
     state.value['current_chat_id'] = chat_id

     # Call the LLM model with streaming
-    llm_api_endpoint = API_BASE_URL + "api/v1/llm/generation/stream/text"
+    llm_api_endpoint = urljoin(API_BASE_URL, "api/v1/llm/generation/stream/text")
     llm_params = {
         "system_prompt": system_prompt,
         "temperature": temperature,
@@ -243,7 +244,7 @@ async def llm_text_completion_memory_stream_async(
     state.value['current_chat_id'] = chat_id

     # Call the LLM model with streaming
-    llm_api_endpoint = API_BASE_URL + "api/v1/llm/generation/stream/text/memory"
+    llm_api_endpoint = urljoin(API_BASE_URL, "api/v1/llm/generation/stream/text/memory")
     llm_params = {
         "system_prompt": system_prompt,
         "temperature": temperature,
@@ -310,7 +311,7 @@ async def llm_text_completion_stream_arena(
     state.value['current_chat_id'] = chat_id

     # Call the LLM model with streaming
-    llm_api_endpoint = API_BASE_URL + "api/v1/llm/generation/stream/text"
+    llm_api_endpoint = urljoin(API_BASE_URL, "api/v1/llm/generation/stream/text")
     llm_params_1 = {
         "system_prompt": system_prompt_1,
         "temperature": temperature_1,
@@ -401,7 +402,7 @@ async def llm_text_completion_memory_stream_arena(
     state.value['current_chat_id'] = chat_id

     # Call the LLM model with streaming
-    llm_api_endpoint = API_BASE_URL + "api/v1/llm/generation/stream/text/memory"
+    llm_api_endpoint = urljoin(API_BASE_URL, "api/v1/llm/generation/stream/text/memory")
     llm_params_1 = {
         "system_prompt": system_prompt_1,
         "temperature": temperature_1,
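Every hunk in this file makes the same substitution, so it is worth spelling out what `urljoin` buys over plain `+`. A minimal sketch, with illustrative URLs:

```python
from urllib.parse import urljoin

# With a trailing slash on the base, urljoin and concatenation agree:
print(urljoin("http://localhost:5000/", "api/v1/ratings/"))
# -> http://localhost:5000/api/v1/ratings/

# Concatenation silently mangles the URL when the slash is missing...
print("http://localhost:5000" + "api/v1/ratings/")
# -> http://localhost:5000api/v1/ratings/

# ...while urljoin still resolves the path against the host:
print(urljoin("http://localhost:5000", "api/v1/ratings/"))
# -> http://localhost:5000/api/v1/ratings/

# Caveat: a base with a non-empty path must end in '/', otherwise urljoin
# replaces the last segment instead of appending to it:
print(urljoin("http://host/app", "api/v1/ratings/"))   # -> http://host/api/v1/ratings/
print(urljoin("http://host/app/", "api/v1/ratings/"))  # -> http://host/app/api/v1/ratings/
```

Note that the relative paths are written without a leading slash: `urljoin(base, "/api/v1/...")` would discard any path already present on the base.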
30 changes: 14 additions & 16 deletions frontend/gradio/utils.py
@@ -1,9 +1,8 @@
 # Reference: https://github.com/lm-sys/FastChat/fastchat/serve/gradio_web_server.py
 from typing import List, Dict, Optional, Tuple, Any
 import aiohttp
 import requests
 import gradio as gr
-from backend.app.core.constants import API_BASE_URL
+from urllib.parse import urljoin
 from backend.app.core.constants import (
     API_BASE_URL,
     TEXT_API_QUOTA_LIMIT
@@ -23,7 +22,8 @@ def get_client_ip(request: gr.Request) -> str:
 def fetch_available_models(llm_type: str) -> Dict[str, str]:
     """Fetch available large language models from the base API URL."""
     models = {}
-    response = requests.get(API_BASE_URL + f"api/v1/llms/type/{llm_type}")
+    url = urljoin(API_BASE_URL, f"api/v1/llms/type/{llm_type}")
+    response = requests.get(url)
     assert response.status_code == 200
     model_info = response.json()

@@ -33,25 +33,24 @@ def fetch_available_models(llm_type: str) -> Dict[str, str]:
     return models

 async def start_chat(user_id: str, mode: str) -> str:
-    chat_api_endpoint = API_BASE_URL + "api/v1/chats/async"
+    url = urljoin(API_BASE_URL, "api/v1/chats/async")
     chat_data = {"user_id": user_id, "mode": mode}

     async with aiohttp.ClientSession() as session:
-        async with session.post(chat_api_endpoint, json=chat_data) as response:
+        async with session.post(url, json=chat_data) as response:
             assert response.status == 200
             data = await response.json()
             return data["id"]


 async def add_message_to_db(session: aiohttp.ClientSession, chat_id: str, content: str, message_type: str, origin: str):
-    message_api_endpoint = API_BASE_URL + "api/v1/messages/async"
+    url = urljoin(API_BASE_URL, "api/v1/messages/async")
     message_data = {
         "chat_id": chat_id,
         "content": content,
         "message_type": message_type,
         "origin": origin
     }
-    async with session.post(message_api_endpoint, json=message_data) as message_response:
+    async with session.post(url, json=message_data) as message_response:
         assert message_response.status == 200

 def load_terms_of_use_js() -> str:
@@ -118,21 +117,21 @@ def load_demo(url_params: gr.JSON, request: gr.Request) -> Tuple[Optional[Dict[s
     client_ip = get_client_ip(request)

     # Check if the client is a first-time user
-    user_ip_api_endpoint = base_api_url + f"api/v1/users/ip/{client_ip}"
+    user_ip_api_endpoint = urljoin(base_api_url, f"api/v1/users/ip/{client_ip}")
     user_ip_response = requests.get(user_ip_api_endpoint)
     first_time_user = user_ip_response.status_code != 200

     if first_time_user:
         # Create User Profile based on IP address
-        user_api_endpoint = base_api_url + "api/v1/users/"
+        user_api_endpoint = urljoin(base_api_url, "api/v1/users/")
         user_data = {"ip_address": client_ip}
         user_response = requests.post(user_api_endpoint, json=user_data)
         user_response.raise_for_status()
         user_id = user_response.json()["id"]
         user_name = user_response.json()["name"]

         # Create a new Quota (Text and Image API) for the new user
-        quota_api_endpoint = base_api_url + "api/v1/quotas/"
+        quota_api_endpoint = urljoin(base_api_url, "api/v1/quotas/")
         for llm_model_type in api_quota_map:
             quota_limit, resource, _, _ = api_quota_map[llm_model_type]
             quota_data = {"users": [{"id": user_id}], "quota_limit": quota_limit, "resource": resource}
@@ -162,7 +161,7 @@ def load_demo(url_params: gr.JSON, request: gr.Request) -> Tuple[Optional[Dict[s
     * See more in the [Gradio](https://www.gradio.app/) 📖
     ## 👇 Choose any models for {task}:
-    """.format(user_name = user_name, task = task)
+    """.format(user_name=user_name, task=task)

     markdown_update = gr.Markdown(notice_markdown)
     state = gr.State({
@@ -174,10 +173,9 @@ def load_demo(url_params: gr.JSON, request: gr.Request) -> Tuple[Optional[Dict[s
         "history": []
     })
     if is_arena_ui:
-        model_1_dropdown_update = gr.Dropdown(label= 'LLM', choices=model_names, value=selected_model_1, visible=True, show_label=False)
-        model_2_dropdown_update = gr.Dropdown(label= 'LLM', choices=model_names, value=selected_model_2, visible=True, show_label=False)
+        model_1_dropdown_update = gr.Dropdown(label='LLM', choices=model_names, value=selected_model_1, visible=True, show_label=False)
+        model_2_dropdown_update = gr.Dropdown(label='LLM', choices=model_names, value=selected_model_2, visible=True, show_label=False)
         return state, markdown_update, model_1_dropdown_update, model_2_dropdown_update
     else:
-        model_dropdown_update = gr.Dropdown(label= 'LLM', choices=model_names, value=selected_model_1, visible=True)
+        model_dropdown_update = gr.Dropdown(label='LLM', choices=model_names, value=selected_model_1, visible=True)
         return state, markdown_update, model_dropdown_update
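For reference, the refactored async helpers in `frontend/gradio/utils.py` would be driven roughly as below; the argument values are illustrative assumptions, not taken from the repository:

```python
import asyncio
import aiohttp

# Assumed import path, matching the file shown above.
from frontend.gradio.utils import start_chat, add_message_to_db

async def demo() -> None:
    # start_chat POSTs to api/v1/chats/async and returns the new chat's id.
    chat_id = await start_chat(user_id="some-user-id", mode="chat")  # illustrative values

    # add_message_to_db takes an existing session, so one ClientSession can be
    # reused across several message writes.
    async with aiohttp.ClientSession() as session:
        await add_message_to_db(session, chat_id, "Hello!", "text", "user")

asyncio.run(demo())
```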
