Commit 25d69d3: update
trisongz committed Feb 2, 2024 (1 parent: 1bc9ab0)
Showing 6 changed files with 216 additions and 15 deletions.
13 changes: 7 additions & 6 deletions async_openai/manager.py
@@ -62,6 +62,7 @@ def __init__(self, **kwargs):
self.no_proxy_client_names: Optional[List[str]] = []
self.client_callbacks: Optional[List[Callable]] = []
self.functions: FunctionManager = OpenAIFunctions
+ self.ctx = ModelContextHandler
if self.auto_loadbalance_clients is None: self.auto_loadbalance_clients = self.settings.auto_loadbalance_clients
if self.auto_healthcheck is None: self.auto_healthcheck = self.settings.auto_healthcheck

@@ -1091,18 +1092,18 @@ def create_embeddings(
model = model,
**kwargs
)
- if strip_newlines: inputs = [i.replace('\n', ' ') for i in inputs]
+ if strip_newlines: inputs = [i.replace('\n', ' ').strip() for i in inputs]
client = self.get_client(model = model, **kwargs)
if not client.is_azure:
- response = client.embeddings.create(input = inputs, auto_retry = auto_retry, **kwargs)
+ response = client.embeddings.create(input = inputs, model = model, auto_retry = auto_retry, **kwargs)
return response.embeddings

embeddings = []
# We need to split into batches of 5 for Azure
# Azure has a limit of 5 inputs per request
batches = split_into_batches(inputs, 5)
for batch in batches:
- response = client.embeddings.create(input = batch, auto_retry = auto_retry, **kwargs)
+ response = client.embeddings.create(input = batch, model = model, auto_retry = auto_retry, **kwargs)
embeddings.extend(response.embeddings)
# Shuffle the clients to load balance
client = self.get_client(model = model, azure_required = True, **kwargs)
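Both the sync and async paths here rely on a batching helper to respect Azure's 5-inputs-per-request cap. A minimal sketch of what that helper does; the library's real `split_into_batches` lives in its utils module, so this standalone version is an illustrative assumption:

```python
from typing import Iterator, List, TypeVar

T = TypeVar('T')

def split_into_batches(items: List[T], batch_size: int) -> Iterator[List[T]]:
    """Yield successive batch_size-sized chunks of items."""
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

# e.g. list(split_into_batches(['a', 'b', 'c', 'd', 'e', 'f'], 5))
# -> [['a', 'b', 'c', 'd', 'e'], ['f']]
```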
@@ -1134,18 +1135,18 @@ async def async_create_embeddings(
model = model,
**kwargs
)
- if strip_newlines: inputs = [i.replace('\n', ' ') for i in inputs]
+ if strip_newlines: inputs = [i.replace('\n', ' ').strip() for i in inputs]
client = self.get_client(model = model, **kwargs)
if not client.is_azure:
- response = await client.embeddings.async_create(input = inputs, auto_retry = auto_retry, **kwargs)
+ response = await client.embeddings.async_create(input = inputs, model = model, auto_retry = auto_retry, **kwargs)
return response.embeddings

embeddings = []
# We need to split into batches of 5 for Azure
# Azure has a limit of 5 inputs per request
batches = split_into_batches(inputs, 5)
for batch in batches:
- response = await client.embeddings.async_create(input = batch, auto_retry = auto_retry, **kwargs)
+ response = await client.embeddings.async_create(input = batch, model = model, auto_retry = auto_retry, **kwargs)
embeddings.extend(response.embeddings)
# Shuffle the clients to load balance
client = self.get_client(model = model, azure_required = True, **kwargs)
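For context, a hedged usage sketch of the manager call these hunks fix; the `OpenAI` session-manager entrypoint and the exact keyword names are assumptions based on the signatures visible in the diff:

```python
import asyncio
from async_openai import OpenAI  # assumed public session manager

async def main():
    # `model` is now forwarded explicitly to embeddings.create /
    # async_create, so non-default embedding models resolve correctly
    # instead of falling back to the client's default.
    embeddings = await OpenAI.async_create_embeddings(
        inputs=["first document", "second document"],
        model="text-embedding-3-small",
        strip_newlines=True,
    )
    print(len(embeddings))  # one vector per input

asyncio.run(main())
```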
8 changes: 6 additions & 2 deletions async_openai/types/context.py
@@ -156,9 +156,13 @@ def get_tokenizer(cls, name: str) -> Optional[tiktoken.Encoding]:
Gets the tokenizer
"""
# Switch the 35 -> 3.5
if '35' in name: name = name.replace('35', '3.5')
if name not in cls.tokenizers:
-     cls.tokenizers[name] = tiktoken.encoding_for_model(name)
+     if name in {'text-embedding-3-small', 'text-embedding-3-large'}:
+         enc_name = 'cl100k_base'
+         cls.tokenizers[name] = tiktoken.get_encoding(enc_name)
+     else:
+         cls.tokenizers[name] = tiktoken.encoding_for_model(name)
return cls.tokenizers[name]

def count_chat_tokens(
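The special case exists because, at the time of this commit, `tiktoken.encoding_for_model` did not yet recognize the `text-embedding-3-*` models and raises `KeyError` for unknown names; both models use the `cl100k_base` encoding. A quick standalone illustration using the standard `tiktoken` API:

```python
import tiktoken

def get_encoding_safe(model: str) -> tiktoken.Encoding:
    # encoding_for_model raises KeyError for models tiktoken doesn't
    # recognize, so pin known-but-unmapped models to cl100k_base.
    if model in {'text-embedding-3-small', 'text-embedding-3-large'}:
        return tiktoken.get_encoding('cl100k_base')
    return tiktoken.encoding_for_model(model)

enc = get_encoding_safe('text-embedding-3-small')
print(len(enc.encode('hello world')))  # token count, e.g. 2
```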
145 changes: 139 additions & 6 deletions async_openai/types/functions.py
@@ -5,6 +5,7 @@
"""

import jinja2
+ import functools
from abc import ABC
from pydantic import PrivateAttr, BaseModel
# from lazyops.types import BaseModel
@@ -13,10 +14,11 @@
from async_openai.utils.fixjson import resolve_json
from . import errors

- from typing import Optional, Any, Dict, List, Union, Type, Tuple, Awaitable, TypeVar, TYPE_CHECKING
+ from typing import Optional, Any, Dict, List, Union, Type, Tuple, Awaitable, Generator, AsyncGenerator, TypeVar, TYPE_CHECKING

if TYPE_CHECKING:
from async_openai import ChatResponse, ChatRoute
+ from async_openai.types.resources import Usage
from async_openai.manager import OpenAIManager as OpenAISessionManager
from lazyops.utils.logs import Logger
from lazyops.libs.persistence import PersistentDict
@@ -28,6 +30,10 @@

class BaseFunctionModel(BaseModel):
_name: Optional[str] = PrivateAttr(None)
+ if TYPE_CHECKING:
+     usage: Optional[Usage]
+ else:
+     usage: Optional[Any] = None

def update(
self,
@@ -107,6 +113,15 @@ class BaseFunction(ABC):
default_model_develop: Optional[str] = None
default_model_production: Optional[str] = None

+ auto_register_function: Optional[bool] = True

+ def __init_subclass__(cls, **kwargs):
+     """
+     Subclass Hook
+     """
+     if cls.auto_register_function:
+         OpenAIFunctions.register_function(cls, initialize = False)


def __init__(
self,
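The new `__init_subclass__` hook means any `BaseFunction` subclass registers itself with the proxy manager at class-definition time, and `initialize = False` defers construction until first use. A minimal self-contained sketch of the pattern (not the library's actual classes):

```python
from typing import Dict, Type

_registry: Dict[str, Type['BaseFunction']] = {}

class BaseFunction:
    auto_register_function: bool = True

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if cls.auto_register_function:
            # Register the class itself; instantiation happens lazily later.
            _registry[cls.__name__] = cls

class SummarizeFunction(BaseFunction):
    pass

print(_registry)  # {'SummarizeFunction': <class '__main__.SummarizeFunction'>}
```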
@@ -431,7 +446,7 @@ def parse_response(
response: 'ChatResponse',
schema: Optional[Type[FunctionSchemaT]] = None,
include_name: Optional[bool] = True,
- ) -> Optional[FunctionSchemaT]:
+ ) -> Optional[FunctionSchemaT]:  # sourcery skip: extract-duplicate-method
"""
Parses the response
"""
Expand All @@ -440,13 +455,15 @@ def parse_response(
result = schema.model_validate(response.function_results[0].arguments, from_attributes = True)
if include_name:
result._name = self.name
+ result.usage = response.usage
return result
except Exception as e:
self.autologger.error(f"[{self.name} - {response.model} - {response.usage}] Failed to parse object: {e}\n{response.text}\n{response.function_results[0].arguments}")
try:
result = schema.model_validate(resolve_json(response.function_results[0].arguments), from_attributes = True)
if include_name:
result._name = self.name
+ result.usage = response.usage
return result
except Exception as e:
self.autologger.error(f"[{self.name} - {response.model} - {response.usage}] Failed to parse object after fixing")
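With `usage` now copied onto the parsed schema in both the happy path and the JSON-repair fallback, token accounting travels with the function result rather than staying on the raw response. A hedged usage sketch (`my_function`, `response`, and `MySchema` are hypothetical placeholders):

```python
# Hypothetical names: my_function is a BaseFunction instance, MySchema a
# BaseFunctionModel subclass, response a ChatResponse from a prior call.
result = my_function.parse_response(response, schema=MySchema)
if result is not None and result.usage is not None:
    print(result.usage)  # e.g. prompt/completion token counts
```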
@@ -624,7 +641,12 @@ def run_function_loop(
if result is not None: return result
attempts += 1
self.autologger.error(f"Unable to parse the response for {self.name} after {self.max_attempts} attempts.")
- if raise_errors: raise errors.MaxRetriesExhausted(name = self.name, attempts = self.max_attempts)
+ if raise_errors: raise errors.MaxRetriesExhausted(
+     name = self.name,
+     func_name = self.name,
+     model = model,
+     attempts = self.max_attempts,
+ )
return None

async def arun_function_loop(
@@ -663,7 +685,12 @@ async def arun_function_loop(
if result is not None: return result
attempts += 1
self.autologger.error(f"Unable to parse the response for {self.name} after {self.max_attempts} attempts.")
- if raise_errors: raise errors.MaxRetriesExhausted(name = self.name, attempts = self.max_attempts)
+ if raise_errors: raise errors.MaxRetriesExhausted(
+     name = self.name,
+     func_name = self.name,
+     model = model,
+     attempts = self.max_attempts,
+ )
return None


@@ -760,6 +787,7 @@ def register_function(
name: Optional[str] = None,
overwrite: Optional[bool] = False,
raise_error: Optional[bool] = False,
+ initialize: Optional[bool] = True,
**kwargs,
):
"""
@@ -768,7 +796,7 @@
if isinstance(func, str):
from lazyops.utils.lazy import lazy_import
func = lazy_import(func)
- if isinstance(func, type):
+ if isinstance(func, type) and initialize:
func = func(**kwargs)
name = name or func.name
if not overwrite and name in self.functions:
@@ -789,11 +817,23 @@ async def acreate_hash(self, **kwargs) -> str:
"""
return await self.api.pooler.asyncish(self.create_hash, **kwargs)

+ def _get_function(self, name: str) -> Optional[BaseFunction]:
+     """
+     Gets the function
+     """
+     func = self.functions.get(name)
+     if not func: return None
+     if isinstance(func, type):
+         func = func(**self._kwargs)
+         self.functions[name] = func
+     return func

def get(self, name: Union[str, 'FunctionT']) -> Optional['FunctionT']:
"""
Gets the function
"""
- return name if isinstance(name, BaseFunction) else self.functions.get(name)
+ return name if isinstance(name, BaseFunction) else self._get_function(name)


def execute(
self,
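`_get_function` is what makes the deferred registration above work: the registry can now hold classes (stored by `__init_subclass__`) or instances, and the first lookup constructs the instance and caches it back into the mapping. The pattern in isolation, as a runnable sketch (the class and kwargs are stand-ins, not the library's types):

```python
from typing import Any, Dict, Type, Union

class Summarize:  # stand-in for a BaseFunction subclass
    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs

# The registry may map names to classes (lazy) or instances (ready).
functions: Dict[str, Union[Type[Summarize], Summarize]] = {'summarize': Summarize}
default_kwargs: Dict[str, Any] = {}

func = functions['summarize']
if isinstance(func, type):           # still a class -> not yet initialized
    func = func(**default_kwargs)    # construct once
    functions['summarize'] = func    # cache the instance for reuse
print(functions['summarize'])        # now a Summarize instance
```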
@@ -918,6 +958,96 @@ def check_value_present(
return True
return False

def map(
self,
function: Union['FunctionT', str],
iterable_kwargs: List[Dict[str, Any]],
*args,
cachable: Optional[bool] = True,
overrides: Optional[List[str]] = None,
return_ordered: Optional[bool] = True,
**function_kwargs
) -> List[Optional['FunctionSchemaT']]:
"""
Maps the function over the iterable in parallel
"""
partial = functools.partial(
self.execute,
function,
cachable = cachable,
overrides = overrides,
**function_kwargs
)
return self.api.pooler.map(partial, iterable_kwargs, *args, return_ordered = return_ordered)

async def amap(
self,
function: Union['FunctionT', str],
iterable_kwargs: List[Dict[str, Any]],
*args,
cachable: Optional[bool] = True,
overrides: Optional[List[str]] = None,
return_ordered: Optional[bool] = True,
concurrency_limit: Optional[int] = None,
**function_kwargs
) -> List[Optional['FunctionSchemaT']]:
"""
Maps the function over the iterable in parallel
"""
partial = functools.partial(
self.aexecute,
function,
cachable = cachable,
overrides = overrides,
**function_kwargs
)
return await self.api.pooler.amap(partial, iterable_kwargs, *args, return_ordered = return_ordered, concurrency_limit = concurrency_limit)

def iterate(
self,
function: Union['FunctionT', str],
iterable_kwargs: List[Dict[str, Any]],
*args,
cachable: Optional[bool] = True,
overrides: Optional[List[str]] = None,
return_ordered: Optional[bool] = False,
**function_kwargs
) -> Generator[Optional['FunctionSchemaT'], None, None]:
"""
Maps the function over the iterable, lazily yielding results as they complete
"""
partial = functools.partial(
self.execute,
function,
cachable = cachable,
overrides = overrides,
**function_kwargs
)
return self.api.pooler.iterate(partial, iterable_kwargs, *args, return_ordered = return_ordered)

def aiterate(
self,
function: Union['FunctionT', str],
iterable_kwargs: List[Dict[str, Any]],
*args,
cachable: Optional[bool] = True,
overrides: Optional[List[str]] = None,
return_ordered: Optional[bool] = False,
concurrency_limit: Optional[int] = None,
**function_kwargs
) -> AsyncGenerator[Optional['FunctionSchemaT'], None]:
"""
Maps the function over the iterable, asynchronously yielding results as they complete
"""
partial = functools.partial(
self.aexecute,
function,
cachable = cachable,
overrides = overrides,
**function_kwargs
)
return self.api.pooler.aiterate(partial, iterable_kwargs, *args, return_ordered = return_ordered, concurrency_limit = concurrency_limit)

def __call__(
self,
function: Union['FunctionT', str],
@@ -940,6 +1070,9 @@ def __call__(
overrides = overrides,
**function_kwargs
)





OpenAIFunctions: FunctionManager = ProxyObject(FunctionManager)
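Taken together, the new `map`/`amap`/`iterate`/`aiterate` helpers fan one registered function out over a list of per-call kwargs through the manager's worker pool. A hedged usage sketch; the function name and its kwargs are hypothetical, and the import path for the proxy follows the diff:

```python
import asyncio
from async_openai.types.functions import OpenAIFunctions

async def main():
    # One dict per call; results come back in input order when
    # return_ordered is True (the default for map/amap).
    results = await OpenAIFunctions.amap(
        'summarize',                                   # hypothetical function
        [{'text': 'first doc'}, {'text': 'second doc'}],
        concurrency_limit=4,                           # bound parallelism
    )
    for schema in results:
        print(schema)  # parsed FunctionSchemaT instances (or None)

asyncio.run(main())
```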
61 changes: 61 additions & 0 deletions async_openai/utils/embedding.py
@@ -0,0 +1,61 @@
from __future__ import annotations

"""
Embedding Utility Helpers
These are borrowed from the `openai` experiments library
- We specifically use lazy loading to avoid runtime errors if the user does not have the required dependencies
"""

from lazyops.libs import lazyload as lz
from lazyops.libs.proxyobj import ProxyObject
from lazyops.types.common import Literal
from typing import Dict, Callable, List, Union, Optional

if lz.TYPE_CHECKING:
from scipy import spatial
import numpy as np
from numpy import ndarray
else:
spatial = lz.LazyLoad("scipy.spatial")
np = lz.LazyLoad("numpy")

def _initialize_distance_dict(*args, **kwargs) -> Dict[str, Callable[..., float]]:
"""
Initializes the distance dictionary
"""
return {
"cosine": spatial.distance.cosine,
"L1": spatial.distance.cityblock,
"L2": spatial.distance.euclidean,
"Linf": spatial.distance.chebyshev,
}


distance_metrics: Dict[str, Callable[..., float]] = ProxyObject(obj_getter = _initialize_distance_dict)

MetricT = Literal["cosine", "L1", "L2", "Linf"]


def distances_from_embeddings(
query_embedding: List[float],
embeddings: List[List[float]],
distance_metric: Optional[MetricT] = "cosine",
) -> List[List]:
"""
Return the distances between a query embedding and a list of embeddings.
"""
return [
distance_metrics[distance_metric](query_embedding, embedding)
for embedding in embeddings
]


def indices_of_nearest_neighbors_from_distances(
distances: 'ndarray',
) -> 'ndarray':
"""
Return a list of indices of nearest neighbors from a list of distances.
"""
return np.argsort(distances)
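A short usage sketch of the two helpers above; it needs the optional `numpy`/`scipy` dependencies added to setup.py in this commit, and the toy 3-vectors are illustrative:

```python
from async_openai.utils.embedding import (
    distances_from_embeddings,
    indices_of_nearest_neighbors_from_distances,
)

query = [0.1, 0.2, 0.7]
corpus = [
    [0.1, 0.2, 0.7],   # identical to the query -> distance 0.0
    [0.9, 0.0, 0.1],   # points elsewhere -> largest cosine distance
    [0.2, 0.2, 0.6],   # close to the query -> small distance
]

distances = distances_from_embeddings(query, corpus, distance_metric='cosine')
order = indices_of_nearest_neighbors_from_distances(distances)
print(list(order))  # [0, 2, 1] -> nearest neighbors first
```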
2 changes: 1 addition & 1 deletion async_openai/version.py
@@ -1 +1 @@
- VERSION = '0.0.50rc0'
+ VERSION = '0.0.50rc1'
2 changes: 2 additions & 0 deletions setup.py
@@ -26,6 +26,8 @@
requirements.append('typing_extensions')

extras = {
+     'cache': ['kvdb'], # Adds caching support
+     'utils': ['numpy', 'scipy'] # Adds embedding utility support
}

args = {
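With these extras in place, the optional dependencies install via pip extras, e.g. `pip install "async-openai[utils]"` for the numpy/scipy embedding helpers or `pip install "async-openai[cache]"` for kvdb-backed caching; the `async-openai` distribution name is assumed from the repository name.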
