diff --git a/CHANGELOG.md b/CHANGELOG.md index 3887264bbb..d59161e03c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). Nothing unreleased! +## [0.7.600rc0] - 2023-11-08 + +### Changed + +- Replaced aiohttp with httpx +- Prompt Playground has been updated to work with the new openai release (v1). Including tools. + ## [0.7.500] - 2023-11-07 ### Added diff --git a/backend/chainlit/client/cloud.py b/backend/chainlit/client/cloud.py index 50d36c3471..6e89a1729f 100644 --- a/backend/chainlit/client/cloud.py +++ b/backend/chainlit/client/cloud.py @@ -1,7 +1,7 @@ import uuid from typing import Any, Dict, List, Optional, Union -import aiohttp +import httpx from chainlit.logger import logger from .base import ( @@ -459,41 +459,37 @@ async def upload_element( if conversation_id: body["conversationId"] = conversation_id - path = f"/api/upload/file" + path = "/api/upload/file" - async with aiohttp.ClientSession() as session: - async with session.post( + async with httpx.AsyncClient() as client: + response = await client.post( f"{self.chainlit_server}{path}", json=body, headers=self.headers, - ) as r: - if not r.ok: - reason = await r.text() - logger.error(f"Failed to sign upload url: {reason}") - return {"object_key": None, "url": None} - json_res = await r.json() + ) + if response.status_code != 200: + reason = response.text + logger.error(f"Failed to sign upload url: {reason}") + return {"object_key": None, "url": None} + json_res = response.json() upload_details = json_res["post"] object_key = upload_details["fields"]["key"] signed_url = json_res["signedUrl"] - form_data = aiohttp.FormData() + # Prepare form data + form_data = upload_details["fields"].copy() + form_data["file"] = (id, content, "multipart/form-data") - # Add fields to the form_data - for field_name, field_value in upload_details["fields"].items(): - form_data.add_field(field_name, field_value) - - # Add file to the form_data - form_data.add_field("file", content, content_type="multipart/form-data") - async with aiohttp.ClientSession() as session: - async with session.post( + async with httpx.AsyncClient() as client: + upload_response = await client.post( upload_details["url"], - data=form_data, - ) as upload_response: - if not upload_response.ok: - reason = await upload_response.text() - logger.error(f"Failed to upload file: {reason}") - return {"object_key": None, "url": None} - + files=form_data, + ) + try: + upload_response.raise_for_status() url = f'{upload_details["url"]}/{object_key}' return {"object_key": object_key, "url": signed_url} + except Exception as e: + logger.error(f"Failed to upload file: {str(e)}") + return {"object_key": None, "url": None} diff --git a/backend/chainlit/langflow/__init__.py b/backend/chainlit/langflow/__init__.py index 9d314308ec..ee8094fc0b 100644 --- a/backend/chainlit/langflow/__init__.py +++ b/backend/chainlit/langflow/__init__.py @@ -7,7 +7,7 @@ from typing import Dict, Optional, Union -import aiohttp +import httpx from chainlit.telemetry import trace_event @@ -16,15 +16,12 @@ async def load_flow(schema: Union[Dict, str], tweaks: Optional[Dict] = None): trace_event("load_langflow") - if type(schema) == str: - async with aiohttp.ClientSession() as session: - async with session.get( - schema, - ) as r: - if not r.ok: - reason = await r.text() - raise ValueError(f"Error: {reason}") - schema = await r.json() + if isinstance(schema, str): + async with httpx.AsyncClient() as client: + response = await client.get(schema) + if response.status_code != 200: + raise ValueError(f"Error: {response.text}") + schema = response.json() flow = load_flow_from_json(flow=schema, tweaks=tweaks) diff --git a/backend/chainlit/oauth_providers.py b/backend/chainlit/oauth_providers.py index 3d3b8ab7a1..1c73aac88a 100644 --- a/backend/chainlit/oauth_providers.py +++ b/backend/chainlit/oauth_providers.py @@ -3,7 +3,7 @@ import urllib.parse from typing import Dict, List, Optional, Tuple -import aiohttp +import httpx from chainlit.client.base import AppUser from fastapi import HTTPException @@ -44,46 +44,44 @@ async def get_token(self, code: str, url: str): "client_secret": self.client_secret, "code": code, } - async with aiohttp.ClientSession( - trust_env=True, raise_for_status=True - ) as session: - async with session.post( + async with httpx.AsyncClient() as client: + response = await client.post( "https://github.com/login/oauth/access_token", - json=payload, - ) as result: - text = await result.text() - content = urllib.parse.parse_qs(text) - token = content.get("access_token", [""])[0] - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token + data=payload, + ) + response.raise_for_status() + content = urllib.parse.parse_qs(response.text) + token = content.get("access_token", [""])[0] + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token async def get_user_info(self, token: str): - async with aiohttp.ClientSession( - trust_env=True, raise_for_status=True - ) as session: - async with session.get( + async with httpx.AsyncClient() as client: + user_response = await client.get( "https://api.github.com/user", headers={"Authorization": f"token {token}"}, - ) as result: - user = await result.json() + ) + user_response.raise_for_status() + user = user_response.json() - async with session.get( - "https://api.github.com/user/emails", - headers={"Authorization": f"token {token}"}, - ) as email_result: - emails = await email_result.json() + emails_response = await client.get( + "https://api.github.com/user/emails", + headers={"Authorization": f"token {token}"}, + ) + emails_response.raise_for_status() + emails = emails_response.json() - user.update({"emails": emails}) + user.update({"emails": emails}) - app_user = AppUser( - username=user["login"], - image=user["avatar_url"], - provider="github", - ) - return (user, app_user) + app_user = AppUser( + username=user["login"], + image=user["avatar_url"], + provider="github", + ) + return (user, app_user) class GoogleOAuthProvider(OAuthProvider): @@ -108,35 +106,35 @@ async def get_token(self, code: str, url: str): "grant_type": "authorization_code", "redirect_uri": url, } - async with aiohttp.ClientSession( - trust_env=True, raise_for_status=True - ) as session: - async with session.post( + async with httpx.AsyncClient() as client: + response = await client.post( "https://oauth2.googleapis.com/token", data=payload, - ) as result: - json = await result.json() - token = json["access_token"] - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token + ) + response.raise_for_status() + json = response.json() + token = json.get("access_token") + if not token: + raise httpx.HTTPStatusError( + "Failed to get the access token", + request=response.request, + response=response, + ) + return token async def get_user_info(self, token: str): - async with aiohttp.ClientSession( - trust_env=True, raise_for_status=True - ) as session: - async with session.get( + async with httpx.AsyncClient() as client: + response = await client.get( "https://www.googleapis.com/userinfo/v2/me", headers={"Authorization": f"Bearer {token}"}, - ) as result: - user = await result.json() + ) + response.raise_for_status() + user = response.json() - app_user = AppUser( - username=user["name"], image=user["picture"], provider="google" - ) - return (user, app_user) + app_user = AppUser( + username=user["name"], image=user["picture"], provider="google" + ) + return (user, app_user) class AzureADOAuthProvider(OAuthProvider): @@ -175,52 +173,51 @@ async def get_token(self, code: str, url: str): "grant_type": "authorization_code", "redirect_uri": url, } - async with aiohttp.ClientSession( - trust_env=True, raise_for_status=True - ) as session: - async with session.post( + async with httpx.AsyncClient() as client: + response = await client.post( self.token_url, data=payload, - ) as result: - json = await result.json() - - token = json["access_token"] - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token + ) + response.raise_for_status() + json = response.json() + + token = json["access_token"] + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token async def get_user_info(self, token: str): - async with aiohttp.ClientSession( - trust_env=True, raise_for_status=True - ) as session: - async with session.get( + async with httpx.AsyncClient() as client: + response = await client.get( "https://graph.microsoft.com/v1.0/me", headers={"Authorization": f"Bearer {token}"}, - ) as result: - user = await result.json() - - try: - async with session.get( - "https://graph.microsoft.com/v1.0/me/photos/48x48/$value", - headers={"Authorization": f"Bearer {token}"}, - ) as photo_result: - photo_data = await photo_result.read() - base64_image = base64.b64encode(photo_data) - user[ - "image" - ] = f"data:{photo_result.content_type};base64,{base64_image.decode('utf-8')}" - except Exception as e: - # Ignore errors getting the photo - pass - - app_user = AppUser( - username=user["userPrincipalName"], - image=user.get("image", ""), - provider="azure-ad", + ) + response.raise_for_status() + + user = response.json() + + try: + photo_response = await client.get( + "https://graph.microsoft.com/v1.0/me/photos/48x48/$value", + headers={"Authorization": f"Bearer {token}"}, ) - return (user, app_user) + photo_data = await photo_response.aread() + base64_image = base64.b64encode(photo_data) + user[ + "image" + ] = f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}" + except Exception as e: + # Ignore errors getting the photo + pass + + app_user = AppUser( + username=user["userPrincipalName"], + image=user.get("image", ""), + provider="azure-ad", + ) + return (user, app_user) class OktaOAuthProvider(OAuthProvider): @@ -263,36 +260,34 @@ async def get_token(self, code: str, url: str): "grant_type": "authorization_code", "redirect_uri": url, } - async with aiohttp.ClientSession( - trust_env=True, raise_for_status=True - ) as session: - async with session.post( + async with httpx.AsyncClient() as client: + response = await client.post( f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/token", data=payload, - ) as result: - json = await result.json() - - token = json["access_token"] - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token + ) + response.raise_for_status() + json_data = response.json() + + token = json_data.get("access_token") + if not token: + raise httpx.HTTPStatusError( + "Failed to get the access token", + request=response.request, + response=response, + ) + return token async def get_user_info(self, token: str): - async with aiohttp.ClientSession( - trust_env=True, raise_for_status=True - ) as session: - async with session.get( + async with httpx.AsyncClient() as client: + response = await client.get( f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/userinfo", headers={"Authorization": f"Bearer {token}"}, - ) as result: - user = await result.json() + ) + response.raise_for_status() + user = response.json() - app_user = AppUser( - username=user.get("email"), image="", provider="okta" - ) - return (user, app_user) + app_user = AppUser(username=user.get("email"), image="", provider="okta") + return (user, app_user) class Auth0OAuthProvider(OAuthProvider): @@ -321,36 +316,34 @@ async def get_token(self, code: str, url: str): "grant_type": "authorization_code", "redirect_uri": url, } - async with aiohttp.ClientSession( - trust_env=True, raise_for_status=True - ) as session: - async with session.post( + async with httpx.AsyncClient() as client: + response = await client.post( f"{self.domain}/oauth/token", - json=payload, - ) as result: - json_content = await result.json() - token = json_content.get("access_token") - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token + data=payload, + ) + response.raise_for_status() + json_content = response.json() + token = json_content.get("access_token") + if not token: + raise HTTPException( + status_code=400, detail="Failed to get the access token" + ) + return token async def get_user_info(self, token: str): - async with aiohttp.ClientSession( - trust_env=True, raise_for_status=True - ) as session: - async with session.get( + async with httpx.AsyncClient() as client: + response = await client.get( f"{self.domain}/userinfo", headers={"Authorization": f"Bearer {token}"}, - ) as result: - user = await result.json() - app_user = AppUser( - username=user.get("email"), - image=user.get("picture", ""), - provider="auth0", - ) - return (user, app_user) + ) + response.raise_for_status() + user = response.json() + app_user = AppUser( + username=user.get("email"), + image=user.get("picture", ""), + provider="auth0", + ) + return (user, app_user) class DescopeOAuthProvider(OAuthProvider): @@ -378,31 +371,32 @@ async def get_token(self, code: str, url: str): "grant_type": "authorization_code", "redirect_uri": url, } - async with aiohttp.ClientSession(raise_for_status=True) as session: - async with session.post( + async with httpx.AsyncClient() as client: + response = await client.post( f"{self.domain}/token", - json=payload, - ) as result: - json_content = await result.json() - token = json_content.get("access_token") - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token + data=payload, + ) + response.raise_for_status() + json_content = response.json() + token = json_content.get("access_token") + if not token: + raise httpx.HTTPStatusError( + "Failed to get the access token", + request=response.request, + response=response, + ) + return token async def get_user_info(self, token: str): - async with aiohttp.ClientSession(raise_for_status=True) as session: - async with session.get( - f"{self.domain}/userinfo", - headers={"Authorization": f"Bearer {token}"}, - ) as result: - user = await result.json() - - app_user = AppUser( - username=user.get("email"), image="", provider="descope" - ) - return (user, app_user) + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.domain}/userinfo", headers={"Authorization": f"Bearer {token}"} + ) + response.raise_for_status() # This will raise an exception for 4xx/5xx responses + user = response.json() + + app_user = AppUser(username=user.get("email"), image="", provider="descope") + return (user, app_user) providers = [ diff --git a/backend/chainlit/playground/providers/openai.py b/backend/chainlit/playground/providers/openai.py index 975368991c..309c4132b1 100644 --- a/backend/chainlit/playground/providers/openai.py +++ b/backend/chainlit/playground/providers/openai.py @@ -1,6 +1,5 @@ import json from contextlib import contextmanager -from typing import Dict from chainlit.input_widget import Select, Slider, Tags from chainlit.playground.provider import BaseProvider @@ -8,8 +7,15 @@ from fastapi.responses import StreamingResponse -def stringify_function_call(function_call: Dict): - _function_call = function_call.copy() +def stringify_function_call(function_call): + if isinstance(function_call, dict): + _function_call = function_call.copy() + else: + _function_call = { + "arguments": function_call.arguments, + "name": function_call.name, + } + if "arguments" in _function_call and isinstance(_function_call["arguments"], str): _function_call["arguments"] = json.loads(_function_call["arguments"]) return json.dumps(_function_call, indent=4, ensure_ascii=False) @@ -76,37 +82,32 @@ def handle_openai_error(): try: yield - except openai.error.Timeout as e: + except openai.APITimeoutError as e: raise HTTPException( status_code=408, detail=f"OpenAI API request timed out: {e}", ) - except openai.error.APIError as e: + except openai.APIError as e: raise HTTPException( status_code=500, detail=f"OpenAI API returned an API Error: {e}", ) - except openai.error.APIConnectionError as e: + except openai.APIConnectionError as e: raise HTTPException( status_code=503, detail=f"OpenAI API request failed to connect: {e}", ) - except openai.error.InvalidRequestError as e: - raise HTTPException( - status_code=400, - detail=f"OpenAI API request was invalid: {e}", - ) - except openai.error.AuthenticationError as e: + except openai.AuthenticationError as e: raise HTTPException( status_code=403, detail=f"OpenAI API request was not authorized: {e}", ) - except openai.error.PermissionError as e: + except openai.PermissionDeniedError as e: raise HTTPException( status_code=403, detail=f"OpenAI API request was not permitted: {e}", ) - except openai.error.RateLimitError as e: + except openai.RateLimitError as e: raise HTTPException( status_code=429, detail=f"OpenAI API request exceeded rate limit: {e}", @@ -120,13 +121,11 @@ def format_message(self, message, prompt): async def create_completion(self, request): await super().create_completion(request) - import openai + from openai import AsyncClient env_settings = self.validate_env(request=request) - deployment_id = self.get_var(request, "OPENAI_API_DEPLOYMENT_ID") - if deployment_id: - env_settings["deployment_id"] = deployment_id + client = AsyncClient(api_key=env_settings["api_key"]) llm_settings = request.prompt.settings @@ -150,8 +149,7 @@ async def create_completion(self, request): llm_settings["stream"] = True with handle_openai_error(): - response = await openai.ChatCompletion.acreate( - **env_settings, + response = await client.chat.completions.create( messages=messages, **llm_settings, ) @@ -159,24 +157,20 @@ async def create_completion(self, request): if llm_settings["stream"]: async def create_event_stream(): - async for stream_resp in response: - if hasattr(stream_resp, "choices") and len(stream_resp.choices) > 0: - delta = stream_resp.choices[0]["delta"] - token = delta.get("content", "") - if token: - yield token + async for part in response: + if token := part.choices[0].delta.content or "": + yield token else: continue else: async def create_event_stream(): - message = response.choices[0]["message"] - function_call = message.get("function_call") - if function_call: + message = response.choices[0].message + if function_call := message.function_call: yield stringify_function_call(function_call) else: - yield message.get("content", "") + yield message.content or "" return StreamingResponse(create_event_stream()) @@ -187,15 +181,63 @@ def message_to_string(self, message): async def create_completion(self, request): await super().create_completion(request) - import openai + from openai import AsyncClient env_settings = self.validate_env(request=request) - deployment_id = self.get_var(request, "OPENAI_API_DEPLOYMENT_ID") + client = AsyncClient(api_key=env_settings["api_key"]) + + llm_settings = request.prompt.settings + + self.require_settings(llm_settings) + + prompt = self.create_prompt(request) + + if "stop" in llm_settings: + stop = llm_settings["stop"] + + # OpenAI doesn't support an empty stop array, clear it + if isinstance(stop, list) and len(stop) == 0: + stop = None + + llm_settings["stop"] = stop + + llm_settings["stream"] = True + + with handle_openai_error(): + response = await client.completions.create( + prompt=prompt, + **llm_settings, + ) - if deployment_id: - env_settings["deployment_id"] = deployment_id + async def create_event_stream(): + async for part in response: + if token := part.choices[0].text or "": + yield token + else: + continue + return StreamingResponse(create_event_stream()) + + +class AzureOpenAIProvider(BaseProvider): + def message_to_string(self, message): + return message.to_string() + + async def create_completion(self, request): + await super().create_completion(request) + from openai import AsyncAzureOpenAI + + env_settings = self.validate_env(request=request) + + client = AsyncAzureOpenAI( + api_key=env_settings["api_key"], + api_version=env_settings["api_version"], + azure_endpoint=env_settings["azure_endpoint"], + azure_ad_token=self.get_var(request, "AZURE_AD_TOKEN"), + azure_ad_token_provider=self.get_var(request, "AZURE_AD_TOKEN_PROVIDER"), + azure_deployment=self.get_var(request, "AZURE_DEPLOYMENT"), + ) llm_settings = request.prompt.settings self.require_settings(llm_settings) @@ -214,16 +256,87 @@ async def create_completion(self, request): llm_settings["stream"] = True with handle_openai_error(): - response = await openai.Completion.acreate( - **env_settings, + response = await client.completions.create( prompt=prompt, **llm_settings, ) async def create_event_stream(): - async for stream_resp in response: - token = stream_resp.get("choices")[0].get("text") - yield token + async for part in response: + if token := part.choices[0].text or "": + yield token + else: + continue + + return StreamingResponse(create_event_stream()) + + +class AzureChatOpenAIProvider(BaseProvider): + def format_message(self, message, prompt): + message = super().format_message(message, prompt) + return message.to_openai() + + async def create_completion(self, request): + await super().create_completion(request) + from openai import AsyncAzureOpenAI + + env_settings = self.validate_env(request=request) + + client = AsyncAzureOpenAI( + api_key=env_settings["api_key"], + api_version=env_settings["api_version"], + azure_endpoint=env_settings["azure_endpoint"], + azure_ad_token=self.get_var(request, "AZURE_AD_TOKEN"), + azure_ad_token_provider=self.get_var(request, "AZURE_AD_TOKEN_PROVIDER"), + azure_deployment=self.get_var(request, "AZURE_DEPLOYMENT"), + ) + + llm_settings = request.prompt.settings + + self.require_settings(llm_settings) + + messages = self.create_prompt(request) + + if "stop" in llm_settings: + stop = llm_settings["stop"] + + # OpenAI doesn't support an empty stop array, clear it + if isinstance(stop, list) and len(stop) == 0: + stop = None + + llm_settings["stop"] = stop + + llm_settings["model"] = env_settings["deployment_name"] + + if request.prompt.functions: + llm_settings["functions"] = request.prompt.functions + llm_settings["stream"] = False + else: + llm_settings["stream"] = True + + with handle_openai_error(): + response = await client.chat.completions.create( + messages=messages, + **llm_settings, + ) + + if llm_settings["stream"]: + + async def create_event_stream(): + async for part in response: + if token := part.choices[0].delta.content or "": + yield token + else: + continue + + else: + + async def create_event_stream(): + message = response.choices[0].message + if function_call := message.function_call: + yield stringify_function_call(function_call) + else: + yield message.content or "" return StreamingResponse(create_event_stream()) @@ -231,10 +344,10 @@ async def create_event_stream(): openai_env_vars = {"api_key": "OPENAI_API_KEY"} azure_openai_env_vars = { - "api_key": "OPENAI_API_KEY", - "api_type": "OPENAI_API_TYPE", - "api_base": "OPENAI_API_BASE", - "api_version": "OPENAI_API_VERSION", + "api_key": "AZURE_OPENAI_API_KEY", + "api_version": "AZURE_OPENAI_API_VERSION", + "azure_endpoint": "AZURE_OPENAI_ENDPOINT", + "deployment_name": "AZURE_OPENAI_DEPLOYMENT_NAME", } ChatOpenAI = ChatOpenAIProvider( @@ -253,15 +366,6 @@ async def create_event_stream(): is_chat=True, ) - -AzureChatOpenAI = ChatOpenAIProvider( - id="azure-openai-chat", - env_vars=azure_openai_env_vars, - name="AzureChatOpenAI", - inputs=openai_common_inputs, - is_chat=True, -) - OpenAI = OpenAIProvider( id="openai", name="OpenAI", @@ -278,7 +382,16 @@ async def create_event_stream(): is_chat=False, ) -AzureOpenAI = OpenAIProvider( + +AzureChatOpenAI = AzureChatOpenAIProvider( + id="azure-openai-chat", + env_vars=azure_openai_env_vars, + name="AzureChatOpenAI", + inputs=openai_common_inputs, + is_chat=True, +) + +AzureOpenAI = AzureOpenAIProvider( id="azure", name="AzureOpenAI", env_vars=azure_openai_env_vars, diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 79aac6368e..06c06f253c 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "chainlit" -version = "0.7.501" +version = "0.7.600rc0" keywords = ['LLM', 'Agents', 'gen ai', 'chat ui', 'chatbot ui', 'langchain'] description = "A faster way to build chatbot UIs." authors = ["Chainlit"] @@ -19,12 +19,12 @@ include = [ chainlit = 'chainlit.cli:cli' [tool.poetry.dependencies] -python = "^3.8.1" +python = ">=3.8.1,<3.12" dataclasses_json = "^0.5.7" uvicorn = "^0.23.2" fastapi = "^0.100" fastapi-socketio = "^0.0.10" -aiohttp = "^3.8.4" +httpx = "^0.25.1" aiofiles = "^23.1.0" syncer = "^2.0.3" asyncer = "^0.0.2" @@ -46,9 +46,9 @@ pyjwt = "^2.8.0" optional = true [tool.poetry.group.tests.dependencies] -openai = "^0.27.7" -langchain = "^0.0.267" -llama-index = "^0.8.3" +openai = ">=1.1.0" +langchain = "^0.0.331" +llama-index = "^0.8.64" transformers = "^4.30.1" matplotlib = "3.7.1" farm-haystack = "^1.18.0" diff --git a/frontend/src/components/organisms/chat/Messages/container.tsx b/frontend/src/components/organisms/chat/Messages/container.tsx index e53d240e13..94d927630c 100644 --- a/frontend/src/components/organisms/chat/Messages/container.tsx +++ b/frontend/src/components/organisms/chat/Messages/container.tsx @@ -11,7 +11,8 @@ import { IAvatarElement, IFunction, IMessage, - IMessageElement + IMessageElement, + ITool } from '@chainlit/components'; import { playgroundState } from 'state/playground'; @@ -57,25 +58,36 @@ const MessageContainer = memo( const onPlaygroundButtonClick = useCallback( (message: IMessage) => { - setPlayground((old) => ({ - ...old, - prompt: message.prompt - ? { - ...message.prompt, - functions: - (message.prompt.settings - ?.functions as unknown as IFunction[]) || [] - } - : undefined, - originalPrompt: message.prompt - ? { - ...message.prompt, - functions: - (message.prompt.settings - ?.functions as unknown as IFunction[]) || [] - } - : undefined - })); + setPlayground((old) => { + let functions = + (message.prompt?.settings?.functions as unknown as IFunction[]) || + []; + const tools = + (message.prompt?.settings?.tools as unknown as ITool[]) || []; + if (tools.length) { + functions = [ + ...functions, + ...tools + .filter((t) => t.type === 'function') + .map((t) => t.function) + ]; + } + return { + ...old, + prompt: message.prompt + ? { + ...message.prompt, + functions + } + : undefined, + originalPrompt: message.prompt + ? { + ...message.prompt, + functions + } + : undefined + }; + }); }, [setPlayground] ); diff --git a/libs/components/src/inputs/selects/SelectInput.tsx b/libs/components/src/inputs/selects/SelectInput.tsx index 613b54f1e4..14db1164c9 100644 --- a/libs/components/src/inputs/selects/SelectInput.tsx +++ b/libs/components/src/inputs/selects/SelectInput.tsx @@ -91,7 +91,7 @@ const SelectInput = ({ id: id, name: name || id, sx: { - color: grey[600], + color: 'text.primary', fontSize: '14px', fontWeight: 400, px: '16px', diff --git a/libs/components/src/types/playground.ts b/libs/components/src/types/playground.ts index 93e2c00e3e..68cc04f496 100644 --- a/libs/components/src/types/playground.ts +++ b/libs/components/src/types/playground.ts @@ -32,6 +32,11 @@ export interface IFunction { }; } +export interface ITool { + type: string; + function: IFunction; +} + export type PromptMode = 'Template' | 'Formatted'; export interface IPlayground {