diff --git a/src/BingService.py b/src/BingService.py
index 8405873..22764c7 100644
--- a/src/BingService.py
+++ b/src/BingService.py
@@ -5,7 +5,7 @@
 import requests
 import yaml
 
-from Util import setup_logger, get_project_root
+from Util import setup_logger, get_project_root, storage_cached
 from text_extract.html.beautiful_soup import BeautifulSoupSvc
 from text_extract.html.trafilatura import TrafilaturaSvc
 
@@ -21,6 +21,7 @@ def __init__(self, config):
         elif extract_svc == 'beautifulsoup':
             self.txt_extract_svc = BeautifulSoupSvc()
 
+    @storage_cached('bing_search_website', 'query')
     def call_bing_search_api(self, query: str) -> pd.DataFrame:
         logger.info("BingService.call_bing_search_api. query: " + query)
         subscription_key = self.config.get('bing_search').get('subscription_key')
@@ -81,6 +82,7 @@ def call_one_url(self, website_tuple):
         logger.info(f"  receive sentences: {len(sentences)}")
         return sentences, name, url, url_id, snippet
 
+    @storage_cached('bing_search_website_content', 'website_df')
     def call_urls_and_extract_sentences_concurrent(self, website_df):
         logger.info(f"BingService.call_urls_and_extract_sentences_async. website_df.shape: {website_df.shape}")
         with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
diff --git a/src/LLMService.py b/src/LLMService.py
index e20c2d1..19787bb 100644
--- a/src/LLMService.py
+++ b/src/LLMService.py
@@ -6,7 +6,7 @@
 import pandas as pd
 import yaml
 
-from Util import setup_logger, get_project_root
+from Util import setup_logger, get_project_root, storage_cached
 
 logger = setup_logger('LLMService')
 
@@ -103,6 +103,7 @@ def __init__(self, config):
             raise Exception("OpenAI API key is not set.")
         openai.api_key = open_api_key
 
+    @storage_cached('openai', 'prompt')
     def call_api(self, prompt: str):
         openai_api_config = self.config.get('openai_api')
         model = openai_api_config.get('model')
@@ -143,6 +144,7 @@ def __init__(self, config):
         openai.api_key = goose_api_key
         openai.api_base = config.get('goose_ai_api').get('api_base')
 
+    @storage_cached('gooseai', 'prompt')
     def call_api(self, prompt: str):
         logger.info(f"GooseAIService.call_openai_api. len(prompt): {len(prompt)}")
         goose_api_config = self.config.get('goose_ai_api')
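The `storage_cached` decorator applied above (defined in the `Util.py` hunk below) imposes two contracts on the decorated method: the owning object must expose a `config` attribute, and the cached argument must be passed as a keyword, because the decorator looks it up in `**kwargs`. A minimal sketch of that contract, using a hypothetical `EchoService` standing in for the real services and assuming the repo's `src/` directory is on `PYTHONPATH`:

from Util import storage_cached  # assumption: run with src/ on PYTHONPATH

class EchoService:
    """Hypothetical stand-in for BingService/LLMService, for illustration only."""

    def __init__(self, config):
        self.config = config  # required: the decorator reads self.config

    @storage_cached('web', 'search_text')
    def query(self, search_text: str) -> str:
        return search_text.upper()

config = {'cache': {'is_enable': {'web': True}, 'path': '.cache', 'max_number_of_cache': 50}}
svc = EchoService(config)
print(svc.query(search_text='hello'))  # first call runs the method and writes .cache/web/<md5>.pickle
print(svc.query(search_text='hello'))  # second call is served from the cache
# svc.query('hello') would fail: the hash key must be passed as a keyword argument

This keyword-argument requirement is why the call sites in SearchGPTService.py below switch to `call_api(prompt=prompt)` and `call_bing_search_api(query=search_text)`.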
diff --git a/src/SearchGPTService.py b/src/SearchGPTService.py
index 867a3cc..88d3e12 100644
--- a/src/SearchGPTService.py
+++ b/src/SearchGPTService.py
@@ -9,7 +9,7 @@
 from FrontendService import FrontendService
 from LLMService import LLMServiceFactory
 from SemanticSearchService import BatchOpenAISemanticSearchService
-from Util import setup_logger, post_process_gpt_input_text_df, check_result_cache_exists, load_result_from_cache, save_result_cache, check_max_number_of_cache, get_project_root
+from Util import setup_logger, post_process_gpt_input_text_df, get_project_root, storage_cached
 from text_extract.doc import support_doc_type, doc_extract_svc_map
 from text_extract.doc.abc_doc_extract import AbstractDocExtractSvc
 
@@ -60,23 +60,9 @@ def _prompt(self, search_text, text_df, cache_path=None):
         gpt_input_text_df = semantic_search_service.search_related_source(text_df, search_text)
         gpt_input_text_df = post_process_gpt_input_text_df(gpt_input_text_df, self.config.get('openai_api').get('prompt').get('prompt_length_limit'))
 
-        llm_service_provider = self.config.get('llm_service').get('provider')
-        # check if llm result is cached and load if exists
-        if self.config.get('cache').get('is_enable_cache') and check_result_cache_exists(cache_path, search_text, llm_service_provider):
-            logger.info(f"SemanticSearchService.load_result_from_cache. search_text: {search_text}, cache_path: {cache_path}")
-            cache = load_result_from_cache(cache_path, search_text, llm_service_provider)
-            prompt, response_text = cache['prompt'], cache['response_text']
-        else:
-            llm_service = LLMServiceFactory.create_llm_service(self.config)
-            prompt = llm_service.get_prompt_v3(search_text, gpt_input_text_df)
-            response_text = llm_service.call_api(prompt)
-
-            llm_config = self.config.get(f'{llm_service_provider}_api').copy()
-            llm_config.pop('api_key')  # delete api_key to avoid saving it to .cache
-            save_result_cache(cache_path, search_text, llm_service_provider, prompt=prompt, response_text=response_text, config=llm_config)
-
-        # check whether the number of cache exceeds the limit
-        check_max_number_of_cache(cache_path, self.config.get('cache').get('max_number_of_cache'))
+        llm_service = LLMServiceFactory.create_llm_service(self.config)
+        prompt = llm_service.get_prompt_v3(search_text, gpt_input_text_df)
+        response_text = llm_service.call_api(prompt=prompt)
 
         frontend_service = FrontendService(self.config, response_text, gpt_input_text_df)
         source_text, data_json = frontend_service.get_data_json(response_text, gpt_input_text_df)
@@ -94,23 +80,14 @@ def _prompt(self, search_text, text_df, cache_path=None):
 
     def _extract_bing_text_df(self, search_text, cache_path):
         # BingSearch using search_text
-        # check if bing search result is cached and load if exists
         bing_text_df = None
         if not self.config['search_option']['is_use_source'] or not self.config['search_option']['is_enable_bing_search']:
             return bing_text_df
 
-        if self.config.get('cache').get('is_enable_cache') and check_result_cache_exists(cache_path, search_text, 'bing_search'):
-            logger.info(f"BingService.load_result_from_cache. search_text: {search_text}, cache_path: {cache_path}")
-            cache = load_result_from_cache(cache_path, search_text, 'bing_search')
-            bing_text_df = cache['bing_text_df']
-        else:
-            bing_service = BingService(self.config)
-            website_df = bing_service.call_bing_search_api(search_text)
-            bing_text_df = bing_service.call_urls_and_extract_sentences_concurrent(website_df)
-
-            bing_search_config = self.config.get('bing_search').copy()
-            bing_search_config.pop('subscription_key')  # delete api_key from config to avoid saving it to .cache
-            save_result_cache(cache_path, search_text, 'bing_search', bing_text_df=bing_text_df, config=bing_search_config)
+        bing_service = BingService(self.config)
+        website_df = bing_service.call_bing_search_api(query=search_text)
+        bing_text_df = bing_service.call_urls_and_extract_sentences_concurrent(website_df=website_df)
+
         return bing_text_df
 
     def _extract_doc_text_df(self, bing_text_df):
@@ -143,6 +120,7 @@ def _extract_doc_text_df(self, bing_text_df):
         doc_text_df = pd.DataFrame(doc_sentence_list)
         return doc_text_df
 
+    @storage_cached('web', 'search_text')
     def query_and_get_answer(self, search_text):
         cache_path = Path(self.config.get('cache').get('path'))
         # TODO: strategy pattern to support different text sources (e.g. PDF later)
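One consequence of the new scheme, visible in the `Util.py` hunk below, is that the cache key is `md5(str(config) + hash_key)`: any config change produces a different key, so stale entries are bypassed rather than served. A small self-contained illustration, with the two `str(config)` values invented for the example:

from hashlib import md5

key = 'how to run faster'
config_a = "{'openai_api': {'model': 'text-davinci-003'}}"  # illustrative str(config) before a change
config_b = "{'openai_api': {'model': 'text-curie-001'}}"    # illustrative str(config) after a change

hash_a = md5(config_a.encode() + key.encode()).hexdigest()
hash_b = md5(config_b.encode() + key.encode()).hexdigest()
assert hash_a != hash_b  # same query, different config -> different cache file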
diff --git a/src/Util.py b/src/Util.py
index ee8c259..19e9890 100644
--- a/src/Util.py
+++ b/src/Util.py
@@ -2,7 +2,8 @@
 import os
 import pickle
 import re
-import shutil
+from copy import deepcopy
+from functools import wraps
 from hashlib import md5
 from pathlib import Path
 
@@ -37,34 +38,31 @@ def post_process_gpt_input_text_df(gpt_input_text_df, prompt_length_limit):
     return gpt_input_text_df
 
 
-def save_result_cache(path: Path, search_text: str, cache_type: str = 'bing_search', **kwargs):
-    cache_dir = path / md5(search_text.encode()).hexdigest()
-
+def save_result_cache(path: Path, cache_hash: str, cache_type: str, **kwargs):
+    cache_dir = path / cache_type
     os.makedirs(cache_dir, exist_ok=True)
-    path = Path(cache_dir, f'{cache_type}.pickle')
+    path = Path(cache_dir, f'{cache_hash}.pickle')
     with open(path, 'wb') as f:
         pickle.dump(kwargs, f)
 
 
-def load_result_from_cache(path: Path, search_text: str, cache_type: str = 'bing_search'):
-    path = path / f'{md5(search_text.encode()).hexdigest()}' / f'{cache_type}.pickle'
+def load_result_from_cache(path: Path, cache_hash: str, cache_type: str):
+    path = path / cache_type / f'{cache_hash}.pickle'
     with open(path, 'rb') as f:
         return pickle.load(f)
 
 
-def check_result_cache_exists(path: Path, search_text: str, cache_type: str = 'bing_search') -> bool:
-    path = path / f'{md5(search_text.encode()).hexdigest()}' / f'{cache_type}.pickle'
-    if os.path.exists(path):
-        return True
-    else:
-        return False
+def check_result_cache_exists(path: Path, cache_hash: str, cache_type: str) -> bool:
+    path = path / cache_type / f'{cache_hash}.pickle'
+    return os.path.exists(path)
 
 
-def check_max_number_of_cache(path: Path, max_number_of_cache: int = 10):
-    if len(os.listdir(path)) >= max_number_of_cache:
+def check_max_number_of_cache(path: Path, cache_type: str, max_number_of_cache: int = 10):
+    path = path / cache_type
+    if len(os.listdir(path)) > max_number_of_cache:
         ctime_list = [(os.path.getctime(path / file), file) for file in os.listdir(path)]
         oldest_file = sorted(ctime_list)[0][1]
-        shutil.rmtree(path / oldest_file)
+        os.remove(path / oldest_file)
 
 
 def split_sentences_from_paragraph(text):
@@ -72,6 +70,52 @@ def split_sentences_from_paragraph(text):
     return sentences
 
 
+def remove_api_keys(d):
+    key_to_remove = ['api_key', 'subscription_key']
+    temp_key_list = []
+    for key, value in d.items():
+        if key in key_to_remove:
+            temp_key_list += [key]
+        if isinstance(value, dict):
+            remove_api_keys(value)
+
+    for key in temp_key_list:
+        d.pop(key)
+    return d
+
+
+def storage_cached(cache_type: str, cache_hash_key_name: str):
+    def storage_cache_decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            assert hasattr(args[0], 'config'), 'storage_cached is only applicable to a class method with a config attribute'
+            assert cache_hash_key_name in kwargs, f'Target method does not have a {cache_hash_key_name} keyword argument'
+
+            config = getattr(args[0], 'config')
+            if config.get('cache').get('is_enable').get(cache_type):
+                hash_key = str(kwargs[cache_hash_key_name])
+
+                cache_path = Path(get_project_root(), config.get('cache').get('path'))
+                cache_hash = md5(str(config).encode() + hash_key.encode()).hexdigest()
+
+                if check_result_cache_exists(cache_path, cache_hash, cache_type):
+                    result = load_result_from_cache(cache_path, cache_hash, cache_type)['result']
+                else:
+                    result = func(*args, **kwargs)
+                    config_for_cache = deepcopy(config)
+                    config_for_cache = remove_api_keys(config_for_cache)  # never persist API keys to .cache
+                    save_result_cache(cache_path, cache_hash, cache_type, result=result, config=config_for_cache)
+
+                    check_max_number_of_cache(cache_path, cache_type, config.get('cache').get('max_number_of_cache'))
+            else:
+                result = func(*args, **kwargs)
+
+            return result
+
+        return wrapper
+
+    return storage_cache_decorator
+
 if __name__ == '__main__':
     text = "There are many things you can do to learn how to run faster, Mr. Wan, such as incorporating speed workouts into your running schedule, running hills, counting your strides, and adjusting your running form. Lean forward when you run and push off firmly with each foot. Pump your arms actively and keep your elbows bent at a 90-degree angle. Try to run every day, and gradually increase the distance you run for long-distance runs. Make sure you rest at least one day per week to allow your body to recover. Avoid running with excess gear that could slow you down."
     sentences = split_sentences_from_paragraph(text)
diff --git a/src/config/config.yaml b/src/config/config.yaml
index e71322f..da2a1b4 100644
--- a/src/config/config.yaml
+++ b/src/config/config.yaml
@@ -29,6 +29,11 @@ goose_ai_api:
   model: gpt-neo-20b
   max_tokens: 100
 cache: # .cache result for efficiency and consistency
-  is_enable_cache: false
+  is_enable:
+    web: false
+    bing_search_website: false
+    bing_search_website_content: false
+    openai: false
+    gooseai: false
   path: .cache
-  max_number_of_cache: 0
+  max_number_of_cache: 50
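Cache entries now live under one directory per type, `.cache/<cache_type>/<md5>.pickle`, instead of one directory per query hash, which is why eviction switched from `shutil.rmtree` to `os.remove`. A round-trip sketch of the reworked helpers (again assuming `src/` is importable; the hashed bytes are arbitrary example input):

from hashlib import md5
from pathlib import Path

from Util import check_result_cache_exists, load_result_from_cache, save_result_cache

cache_path = Path('.cache')
cache_hash = md5(b'example query').hexdigest()

save_result_cache(cache_path, cache_hash, 'openai', result='cached response')
assert check_result_cache_exists(cache_path, cache_hash, 'openai')
print(load_result_from_cache(cache_path, cache_hash, 'openai')['result'])  # -> cached response
# on disk: .cache/openai/<md5-hash>.pickle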
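The single `is_enable_cache` switch becomes one toggle per cache type, matching the `cache_type` strings passed to `storage_cached`, so each decorated method consults only its own flag and caching can be enabled selectively. For example, to cache only LLM responses (a sketch; the path assumes the repo root as working directory):

import yaml

with open('src/config/config.yaml') as f:
    config = yaml.safe_load(f)

# Enable caching for OpenAI calls only; Bing search and web-level caching stay off.
config['cache']['is_enable']['openai'] = True

# Inside storage_cached, each decorated method checks its own switch:
#   config.get('cache').get('is_enable').get(cache_type)
print(config['cache']['is_enable'])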