Skip to content

Commit

Permalink
Put base url in client
Browse files Browse the repository at this point in the history
  • Loading branch information
Bobholamovic committed Dec 22, 2023
1 parent 09ab4cd commit 99a1de4
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 42 deletions.
7 changes: 2 additions & 5 deletions erniebot/src/erniebot/backends/aistudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,9 @@ def request(
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, Iterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand All @@ -86,7 +85,6 @@ def request(
headers=headers,
files=files,
request_timeout=request_timeout,
base_url=self.base_url,
)

async def arequest(
Expand All @@ -100,10 +98,9 @@ async def arequest(
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, AsyncIterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand Down
11 changes: 6 additions & 5 deletions erniebot/src/erniebot/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,16 @@ class EBBackend(object):

def __init__(self, config_dict: ConfigDictType) -> None:
super().__init__()

self.api_type = self.API_TYPE
self.base_url = config_dict.get("api_base_url", None) or self.BASE_URL
self._cfg = config_dict
self._client = EBClient(
session=self._cfg.get("requests_session", None), asession=self._cfg.get("aiohttp_session", None), response_handler=self.handle_response, proxy=self._cfg.get("proxy", None),)
self.base_url,
session=self._cfg.get("requests_session", None),
asession=self._cfg.get("aiohttp_session", None),
response_handler=self.handle_response,
proxy=self._cfg.get("proxy", None),
)

def handle_response(self, resp: EBResponse) -> EBResponse:
raise NotImplementedError
Expand Down Expand Up @@ -61,6 +65,3 @@ async def arequest(
request_timeout: Optional[float] = None,
) -> Union[EBResponse, AsyncIterator[EBResponse]]:
raise NotImplementedError

def _get_url(self, path: str) -> str:
return f"{self.base_url}{path}"
12 changes: 4 additions & 8 deletions erniebot/src/erniebot/backends/bce.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,9 @@ def request(
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, Iterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand Down Expand Up @@ -101,10 +100,9 @@ async def arequest(
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, AsyncIterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand Down Expand Up @@ -166,10 +164,9 @@ def request(
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, Iterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand All @@ -196,10 +193,9 @@ async def arequest(
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, AsyncIterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand Down
6 changes: 2 additions & 4 deletions erniebot/src/erniebot/backends/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,9 @@ def request(
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, Iterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand All @@ -73,10 +72,9 @@ async def arequest(
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, AsyncIterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand Down
37 changes: 17 additions & 20 deletions erniebot/src/erniebot/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@
import asyncio
import http
import json
import time
from contextlib import contextmanager, asynccontextmanager
from contextlib import asynccontextmanager, contextmanager
from json import JSONDecodeError
from typing import (
Any,
Expand All @@ -52,12 +51,12 @@
Callable,
ClassVar,
Dict,
Generator,
Iterator,
Mapping,
Optional,
Tuple,
Union,
Generator,
)

import aiohttp
Expand Down Expand Up @@ -86,12 +85,15 @@ class EBClient(object):

def __init__(
self,
base_url: str,
*,
session: Optional[requests.Session] = None,
asession: Optional[aiohttp.ClientSession] = None,
response_handler: Optional[Callable[[EBResponse], EBResponse]] = None,
proxy: Optional[str] = None,
) -> None:
super().__init__()
self._base_url = base_url
self._session = session
self._asession = asession
self._resp_handler = response_handler
Expand All @@ -100,13 +102,13 @@ def __init__(
def prepare_request(
self,
method: str,
url: str,
path: str,
supplied_headers: Optional[HeadersType],
params: Optional[ParamsType],
files: Optional[FilesType],
) -> Tuple[str, HeadersType, Optional[bytes]]:
url = f"{self._base_url}{path}"
headers = self._validate_headers(supplied_headers)

data = None
method = method.upper()
if method == "GET" or method == "DELETE":
Expand All @@ -118,7 +120,6 @@ def prepare_request(
headers["Content-Type"] = "application/json"
else:
raise errors.ConnectionError(f"Unrecognized HTTP method: {repr(method)}")

headers = self.get_request_headers(method, headers)

logging.debug("Method: %s", method)
Expand Down Expand Up @@ -185,7 +186,7 @@ def wrap_resp(resp: Iterator) -> Iterator[EBResponse]:
if should_clean_up_ctx:
# We don't care about the exception type and the stack trace.
ctx.__exit__(None, None, None)

return resp

async def asend_request(
Expand Down Expand Up @@ -223,6 +224,7 @@ async def asend_request(
f"but got a {'streamed' if got_stream else 'non-streamed'} response. "
)
if got_stream:

async def wrap_resp(resp: AsyncIterator) -> AsyncIterator[EBResponse]:
try:
async for r in resp:
Expand Down Expand Up @@ -281,7 +283,7 @@ def send_request_raw(
raise errors.TimeoutError(f"Request timed out: {e}") from e
except requests.exceptions.RequestException as e:
raise errors.ConnectionError(f"Error communicating with server: {e}") from e

return result

async def asend_request_raw(
Expand Down Expand Up @@ -381,8 +383,7 @@ def _interpret_response(
)

async def _interpret_async_response(
self,
response: aiohttp.ClientResponse
self, response: aiohttp.ClientResponse
) -> Tuple[Union[EBResponse, AsyncIterator[EBResponse]], bool]:
if "Content-Type" in response.headers and response.headers["Content-Type"].startswith(
"text/event-stream"
Expand All @@ -407,19 +408,13 @@ async def _interpret_async_response(
False,
)

def _interpret_stream_response(
self,
response: requests.Response
) -> Iterator[EBResponse]:
def _interpret_stream_response(self, response: requests.Response) -> Iterator[EBResponse]:
for line in self._parse_stream(response.iter_lines()):
resp = self._interpret_response_line(
line, response.status_code, response.headers, stream=True
)
resp = self._interpret_response_line(line, response.status_code, response.headers, stream=True)
yield resp

async def _interpret_async_stream_response(
self,
response: aiohttp.ClientResponse
self, response: aiohttp.ClientResponse
) -> AsyncIterator[EBResponse]:
async for line in self._parse_async_stream(response.content):
resp = self._interpret_response_line(line, response.status, response.headers, stream=True)
Expand Down Expand Up @@ -497,7 +492,9 @@ def _make_requests_session_context_manager(self) -> Generator[requests.Session,
session.close()

@asynccontextmanager
async def _make_aiohttp_session_context_manager(self) -> AsyncGenerator[aiohttp.ClientSession, None, None]:
async def _make_aiohttp_session_context_manager(
self,
) -> AsyncGenerator[aiohttp.ClientSession, None]:
# TODO: Support proxies
if self._asession is not None:
session = self._asession
Expand Down

0 comments on commit 99a1de4

Please sign in to comment.