Skip to content

Commit

Permalink
Update multithreaded file I/O for the completions cache
Browse files Browse the repository at this point in the history
  • Loading branch information
luyaxi committed Sep 18, 2024
1 parent 6cb4e1e commit 8da1593
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 64 deletions.
3 changes: 2 additions & 1 deletion codelinker/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ class RequestConfig(BaseModel):
save_completions: bool = False
# default to save in current working directory
save_completions_path: str = os.path.join(os.getcwd(),"cache","completions")

fileio_max_workers: int = 16

request: RequestConfig = RequestConfig()

class ExectionConfig(BaseModel):
Expand Down
148 changes: 87 additions & 61 deletions codelinker/request/objGen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
import jsonschema
import jsonschema.exceptions
import importlib
import aiofiles
import asyncio

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Literal
from copy import deepcopy
from tenacity import AsyncRetrying, RetryError, stop_after_attempt
Expand All @@ -18,61 +17,87 @@
from ..config import CodeLinkerConfig
from ..utils import clip_text


class OBJGenerator:
def __init__(self, config: CodeLinkerConfig, logger: Logger):
self.config = config
self.logger = logger
self.chatcompletion_request_funcs = {}

loop = asyncio.get_event_loop()
loop.run_until_complete(self._load_cache_files())
self.thdpool = ThreadPoolExecutor(
max_workers=self.config.request.fileio_max_workers)

self._load_cache_files()
if self.config.request.save_completions:
os.makedirs(self.config.request.save_completions_path, exist_ok=True)

async def _load_cache_files(self):
os.makedirs(self.config.request.save_completions_path,
exist_ok=True)

def _load_cache_files(self):
if self.config.request.use_cache:
self.logger.warning("use_cache is enabled, loading completions from cache...")
self.logger.warning(
"use_cache is enabled, loading completions from cache...")
self.hash2files = {}
files = glob.glob(os.path.join(self.config.request.save_completions_path, "*.json"))
files = glob.glob(os.path.join(
self.config.request.save_completions_path, "*.json"))
files.sort(key=os.path.getmtime)

async def load_file(file):
async with aiofiles.open(file, 'r') as f:
data = json.loads(await f.read())
self.hash2files[hash(json.dumps(data["request"],sort_keys=True))] = file
tasks = [load_file(file) for file in files]
await asyncio.gather(*tasks)
self.logger.warning("Cache loaded and enabled, which may cause unexpected behavior.")

def load_file(file):
with open(file, 'r') as f:
data = json.load(f)
self.hash2files[hash(json.dumps(
data["request"], sort_keys=True))] = file

tasks = [self.thdpool.submit(load_file, file) for file in files]
for task in as_completed(tasks):
task.result()

self.logger.warning(
"Cache loaded and enabled, which may cause unexpected behavior.")

async def _load_cache_files_async(self, file):
def load_file(file):
with open(file, 'r') as f:
data = json.load(f)
return data

task = self.thdpool.submit(load_file, file)
return task.result()

async def _write_cache_files_async(self, file, data):
def write_file(file, data):
with open(file, 'w') as f:
json.dump(data, f)
task = self.thdpool.submit(write_file, file, data)
task.result()

async def _chatcompletion_request(self, *, request_lib: Literal["openai",] = None, **kwargs) -> dict:
if self.config.request.use_cache:
hash_ = hash(json.dumps(kwargs,sort_keys=True))
hash_ = hash(json.dumps(kwargs, sort_keys=True))
self.logger.debug(f"Request hash: {hash_}")
if hash_ in self.hash2files:
async with aiofiles.open(self.hash2files[hash_], 'r') as f:
data = json.loads(await f.read())
if data["request"] == kwargs:
self.logger.info(f"Cache hit file {self.hash2files[hash_]}, return cached completions.")
# remove cache to avoid duplicate
self.hash2files.pop(hash_)
return data["response"]
data = await self._load_cache_files_async(self.hash2files[hash_])
if data["request"] == kwargs:
self.logger.info(
f"Cache hit file {self.hash2files[hash_]}, return cached completions.")
# remove cache to avoid duplicate
self.hash2files.pop(hash_)
return data["response"]

request_lib = request_lib if request_lib is not None else self.config.request.default_request_lib
response = await self._get_chatcompletion_request_func(request_lib)(config=self.config,**kwargs)


response = await self._get_chatcompletion_request_func(request_lib)(config=self.config, **kwargs)

if self.config.request.save_completions:
async with aiofiles.open(os.path.join(self.config.request.save_completions_path,
datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")+f"{uuid.uuid4().hex}.json"
), 'w') as f:
data = json.dumps({
await self._write_cache_files_async(
file=os.path.join(self.config.request.save_completions_path,
datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f") +
f"{uuid.uuid4().hex}.json"
),
data={
"request": kwargs,
"response": response
},)
await f.write(data)
})

return response

def register_request_lib(self, request_type: str, request_func):
    """Register ``request_func`` as the handler used for ``request_type`` requests."""
    self.chatcompletion_request_funcs.update({request_type: request_func})

Expand All @@ -92,12 +117,11 @@ def _get_embedding_request_func(self, request_type: str):
module, 'embedding_request')
return self.chatcompletion_request_funcs[request_type]


async def chatcompletion(self,
*,
messages: list[dict],
schemas: None | StructureSchema | list[StructureSchema |
list[StructureSchema]] = None,
list[StructureSchema]] = None,
schema_validation: bool = None,
dynamic_json_fix: bool = None,
max_retry_times: int = None,
Expand All @@ -106,26 +130,28 @@ async def chatcompletion(self,
if kwargs[k] is None:
kwargs.pop(k)

request_format = kwargs.pop("request_format", self.config.request.format)
request_format = kwargs.pop(
"request_format", self.config.request.format)
if schemas is None:
request_format = "messages"

# filter messages with empty content
messages = list(filter(lambda x: len(x["content"]) > 0, messages))

max_retry_times = self.config.max_retry_times if max_retry_times is None else max_retry_times
schema_validation = self.config.request.schema_validation if schema_validation is None else schema_validation
dynamic_json_fix = self.config.request.dynamic_json_fix if dynamic_json_fix is None else dynamic_json_fix

if not dynamic_json_fix:
self.logger.debug("Warning: dynamic json fix is disabled, raise error if schema validation failed.")

self.logger.debug(
"Warning: dynamic json fix is disabled, raise error if schema validation failed.")

match request_format:
case "messages":
structuredRets = []
async for attempt in AsyncRetrying(stop=stop_after_attempt(max_retry_times), reraise=True):
with attempt:
response = await self._chatcompletion_request(messages=messages,**kwargs)
response = await self._chatcompletion_request(messages=messages, **kwargs)
for choice in response["choices"]:
structuredRets.append(StructuredRet(
name="response",
Expand Down Expand Up @@ -163,16 +189,16 @@ async def _chatcompletion_with_tool_call(
max_retry_times: int,
**kwargs,
) -> list[StructuredRet] | list[list[StructuredRet]]:

req_template = None

# construct the request template
if schemas is not None:
if "tools" in kwargs or "tool_choice" in kwargs:
raise ValueError(
"You can't provide schemas with tools/tool_choice!")
if isinstance(schemas, StructureSchema):

if isinstance(schemas, StructureSchema):
kwargs["tools"] = [{
"type": "function",
"function": {
Expand All @@ -185,13 +211,13 @@ async def _chatcompletion_with_tool_call(
"type": "function",
"function": {"name": schemas.name}
}

constructed_schema = {
"name": schemas.name,
"description": schemas.description,
"parameters": schemas.parameters
}

elif isinstance(schemas, list):
req_template = {
"type": "object",
Expand Down Expand Up @@ -246,7 +272,7 @@ async def _chatcompletion_with_tool_call(
async for attempt in AsyncRetrying(stop=stop_after_attempt(max_retry_times), reraise=True):
with attempt:
response = await self._chatcompletion_request(**kwargs)

for choice in response["choices"]:
structuredRets = []
for tool_call in choice["message"]["tool_calls"]:
Expand Down Expand Up @@ -285,21 +311,21 @@ async def _chatcompletion_with_tool_call(
raise e
return rets


async def dynamic_json_fixes(self, s, schema, messages: list = [], error_message: str = None) -> dict:

self.logger.warn(f'Schema Validation on string:\n{s}\nfor schema {schema["name"]} failed, trying to fix it...')
self.logger.warn(
f'Schema Validation on string:\n{s}\nfor schema {schema["name"]} failed, trying to fix it...')

ret = await self.chatcompletion(
messages=[
{
"role": "system",
"content": "Your task is to fix the string with schema errors. Remember to keep the target schema in mind. Avoid adding any information about this fix!"
},
{
"role": "user",
"content": clip_text(f"""Generate Content with Correct Schema.\n# Error String\n{s}\n# Error Message\n{error_message[-1024:]}""",max_tokens=self.config.execution.max_message_tokens,clip_end=True)[0]
}],
{
"role": "system",
"content": "Your task is to fix the string with schema errors. Remember to keep the target schema in mind. Avoid adding any information about this fix!"
},
{
"role": "user",
"content": clip_text(f"""Generate Content with Correct Schema.\n# Error String\n{s}\n# Error Message\n{error_message[-1024:]}""", max_tokens=self.config.execution.max_message_tokens, clip_end=True)[0]
}],
schemas=StructureSchema(**schema),
schema_validation=True,
dynamic_json_fix=False,
Expand All @@ -312,7 +338,7 @@ async def schema_valiation(
s: str,
schema: dict,
messages: list[dict],
dynamic_json_fix = None) -> dict:
dynamic_json_fix=None) -> dict:
# load string as instance
try:
d = json.loads(s)
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "codelinker"
version = "0.3.14"
version = "0.3.15"
description = "A tool to link the code with large language models."
authors = ["luyaxi <[email protected]>"]
readme = "README.md"
Expand All @@ -12,7 +12,6 @@ tenacity = ">=8.2.3,<9.0"
jsonschema = ">=4.21.1,<5.0"
toml = ">=0.10.2,<1.0"
tiktoken = ">=0.7.0,<1.0"
aiofiles = ">=24.1.0"


[build-system]
Expand Down

0 comments on commit 8da1593

Please sign in to comment.