Skip to content

Commit

Permalink
Update llm_processor.py
Browse files Browse the repository at this point in the history
Groq API support added, +4 fast and free models added:
- llama3-8b-8192
- llama3-70b-8192
- mixtral-8x7b-32768
- gemma-7b-it

 (Rate-limit and max-token settings are in llm_processor.py to avoid flooding the API with errors; barely tested :D ) 

Configuration to add to config.yaml:

```
  # Groq API configuration
  groq_api_url: "https://api.groq.com/openai/v1/chat/completions"
  groq_api_key: ""  # Enter your Groq API key here
  groq_model: "llama3-70b-8192"  # Model to use from Groq
```
  • Loading branch information
fabriziosalmi authored May 31, 2024
1 parent bad8cf7 commit 4ea66c5
Showing 1 changed file with 123 additions and 35 deletions.
158 changes: 123 additions & 35 deletions llm_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import argparse
import yaml
import time
from pathlib import Path
from datetime import datetime
from requests.adapters import HTTPAdapter
Expand All @@ -13,6 +14,16 @@
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Module-level setup: make sure the working directories exist at import time.
# 'output' holds the input JSON files, 'rewritten' receives the processed results.
output_folder = Path('output')
rewritten_folder = Path('rewritten')

for _dir in (output_folder, rewritten_folder):
    _dir.mkdir(parents=True, exist_ok=True)

# Rough upper bound on the prompt size sent to the Groq API
# (may need adjustment to match the model's actual context limit).
MAX_TOKENS = 32768

def requests_retry_session(retries=3, backoff_factor=0.3, status_forcelist=(500, 502, 504), session=None):
session = session or requests.Session()
retry = Retry(
Expand All @@ -27,6 +38,23 @@ def requests_retry_session(retries=3, backoff_factor=0.3, status_forcelist=(500,
session.mount('https://', adapter)
return session

def estimate_token_count(text):
    """Roughly estimate the token count of *text* (~4 characters per token)."""
    char_count = len(text)
    return char_count / 4

def truncate_content(content, max_tokens):
    """Truncate whitespace-delimited *content* so its estimated token count
    stays within *max_tokens* (same ~4-chars-per-token heuristic as
    estimate_token_count).

    Fix: the previous version summed only the characters of each word and
    ignored the joining spaces, so the rebuilt string could still exceed the
    limit according to estimate_token_count. Each word now costs
    (len(word) + 1) / 4 to account for its separator, guaranteeing the
    result's estimate is <= max_tokens.
    """
    kept_words = []
    estimated_tokens = 0

    for word in content.split():
        # +1 covers the space that will join this word to the previous one.
        estimated_tokens += (len(word) + 1) / 4
        if estimated_tokens > max_tokens:
            break
        kept_words.append(word)

    return ' '.join(kept_words)

def call_openai_api(api_url, combined_content, model, api_key):
client = OpenAI(api_key=api_key)
try:
Expand All @@ -42,12 +70,54 @@ def call_openai_api(api_url, combined_content, model, api_key):
logging.error(f"OpenAI API request failed: {e}")
return None

def call_groq_api(api_url, combined_content, model, api_key):
    """Send *combined_content* to the Groq chat-completions endpoint and return
    the assistant's reply text, or None on failure.

    On a rate-limit error the function sleeps for the interval advertised in
    the error payload (see parse_retry_after) and retries recursively.

    Fix: the previous version left `response` unbound when `post()` raised a
    connection-level RequestException before assignment, so the
    `if response is not None` check in the handler itself raised
    UnboundLocalError; `response` is now initialized to None up front.
    """
    data = json.dumps({
        "model": model,
        "messages": [{"role": "user", "content": combined_content}],
        "stream": False
    })
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}'
    }
    logging.debug(f"Groq API request data: {data}")
    response = None  # must exist even if post() raises before assignment
    try:
        response = requests_retry_session().post(api_url, data=data, headers=headers)
        response.raise_for_status()
        try:
            response_json = response.json()
            logging.debug(f"Groq API response: {response_json}")
            return response_json['choices'][0]['message']['content']
        except json.JSONDecodeError as e:
            logging.error(f"Failed to parse JSON response from Groq API: {e}")
            logging.error(f"Response content: {response.text}")
            return None
    except requests.RequestException as e:
        logging.error(f"Groq API request failed: {e}")
        if response is not None:
            logging.error(f"Groq API response content: {response.text}")
            if 'rate_limit_exceeded' in response.text:
                # NOTE(review): recursion here is unbounded if the API keeps
                # rate-limiting; consider a retry cap. Behavior kept as-is.
                retry_after = parse_retry_after(response.json())
                logging.info(f"Rate limit exceeded. Retrying after {retry_after} seconds.")
                time.sleep(retry_after)
                return call_groq_api(api_url, combined_content, model, api_key)
        return None

def parse_retry_after(response_json):
    """Extract the suggested wait time in seconds from a Groq rate-limit
    error payload.

    The error message typically reads "... try again in <seconds>s ...".
    Returns 60 when the payload does not match that shape.

    Fix: `re` was used without a visible module-level `import re` in the
    shown imports, which would raise NameError at runtime; a function-local
    import makes this helper self-contained either way. `TypeError` is also
    caught so a non-string message falls back to the default instead of
    crashing.
    """
    import re  # local import: `re` is not guaranteed at module level — TODO confirm top-of-file imports

    try:
        message = response_json['error']['message']
        match = re.search(r"try again in (\d+\.?\d*)s", message)
        return float(match.group(1))
    except (KeyError, AttributeError, TypeError):
        return 60  # default wait when the message can't be parsed

def call_ollama_api(api_url, combined_content, model):
data = json.dumps({
"model": model,
"messages": [{"role": "user", "content": combined_content}],
"stream": False
})
logging.debug(f"Ollama API request data: {data}")
try:
response = requests_retry_session().post(api_url, data=data, headers={'Content-Type': 'application/json'})
response.raise_for_status()
Expand All @@ -61,6 +131,8 @@ def call_ollama_api(api_url, combined_content, model):
return None
except requests.RequestException as e:
logging.error(f"Ollama API request failed: {e}")
if response is not None:
logging.error(f"Ollama API response content: {response.text}")
return None

def ensure_proper_punctuation(text):
Expand Down Expand Up @@ -89,8 +161,14 @@ def process_json_file(filepath, api_url, model, api_key, content_prefix, rewritt
logging.info(f"Processing {filepath} - combined content prepared.")
logging.debug(f"Combined content: {combined_content}")

if estimate_token_count(combined_content) > MAX_TOKENS:
logging.info(f"Truncating content to fit within {MAX_TOKENS} tokens.")
combined_content = truncate_content(combined_content, MAX_TOKENS)

if api_type == "openai":
rewritten_content = call_openai_api(api_url, combined_content, model, api_key)
elif api_type == "groq":
rewritten_content = call_groq_api(api_url, combined_content, model, api_key)
else:
rewritten_content = call_ollama_api(api_url, combined_content, model)

Expand Down Expand Up @@ -127,55 +205,65 @@ def process_json_file(filepath, api_url, model, api_key, content_prefix, rewritt
logging.debug(f"Rewritten content: {rewritten_content}")

def validate_config(api_config):
if 'openai_api_url' in api_config and 'ollama_api_url' in api_config:
raise ValueError("Both OpenAI and Ollama API configurations are set. Please configure only one.")
if 'openai_api_url' not in api_config and 'ollama_api_url' not in api_config:
raise ValueError("Neither OpenAI nor Ollama API configuration is set. Please configure one.")
if 'openai_api_url' in api_config and ('openai_api_key' not in api_config or 'openai_model' not in api_config):
raise ValueError("OpenAI API configuration is incomplete. Please set the API URL, API key, and model.")
if 'ollama_api_url' in api_config and 'ollama_model' not in api_config:
raise ValueError("Ollama API configuration is incomplete. Please set the API URL and model.")

def main():
parser = argparse.ArgumentParser(description='Process JSON files and call LLM API.')
parser.add_argument('--config', type=str, help='Path to the configuration file.', default='config.yaml')

args = parser.parse_args()

config_path = args.config if args.config else 'config.yaml'
api_keys = ['openai_api_url', 'groq_api_url', 'ollama_api_url']
active_apis = [key for key in api_keys if key in api_config]

if len(active_apis) != 1:
raise ValueError("Exactly one API configuration must be set in the config file.")

if 'openai_api_url' in api_config:
if 'openai_api_key' not in api_config or 'openai_model' not in api_config:
raise ValueError("OpenAI API configuration is incomplete. Please set the API URL, API key, and model.")
elif 'groq_api_url' in api_config:
if 'groq_api_key' not in api_config or 'groq_model' not in api_config:
raise ValueError("Groq API configuration is incomplete. Please set the API URL, API key, and model.")
elif 'ollama_api_url' in api_config:
if 'ollama_model' not in api_config:
raise ValueError("Ollama API configuration is incomplete. Please set the API URL and model.")

with open(config_path, 'r') as config_file:
config = yaml.safe_load(config_file)
def main(config_path):
try:
with open(config_path, 'r', encoding='utf-8') as file:
config = yaml.safe_load(file)
except (yaml.YAMLError, IOError) as e:
logging.error(f"Error reading config file {config_path}: {e}")
return

api_config = config['api_config']
folders = config['folders']
content_prefix = config['content_prefix']
api_config = config.get('api_config', {})
folder_config = config.get('folders', {})
content_prefix = config.get('content_prefix', "")

validate_config(api_config)

rewritten_folder = Path(folders['rewritten_folder'])
rewritten_folder.mkdir(parents=True, exist_ok=True)

output_folder = Path(folders['output_folder'])
json_files = list(output_folder.glob('*.json'))
if not json_files:
logging.info("No JSON files found in the output folder.")
return

if 'openai_api_url' in api_config:
api_url = api_config['openai_api_url']
model = api_config['openai_model']
api_key = api_config['openai_api_key']
api_type = 'openai'
elif 'groq_api_url' in api_config:
api_url = api_config['groq_api_url']
model = api_config['groq_model']
api_key = api_config['groq_api_key']
api_type = 'groq'
else:
api_url = api_config['ollama_api_url']
model = api_config['ollama_model']
api_key = None
api_key = None # Ollama does not need an API key
api_type = 'ollama'

for filepath in json_files:
logging.info(f"Processing file: {filepath}")
process_json_file(filepath, api_url, model, api_key, content_prefix, rewritten_folder, api_type)
output_folder = folder_config.get('output_folder', 'output')
rewritten_folder = folder_config.get('rewritten_folder', 'rewritten')

Path(rewritten_folder).mkdir(parents=True, exist_ok=True)

json_files = Path(output_folder).glob('*.json')
for json_file in json_files:
process_json_file(json_file, api_url, model, api_key, content_prefix, rewritten_folder, api_type)

if __name__ == "__main__":
main()
parser = argparse.ArgumentParser(description='Process JSON files with LLM API')
parser.add_argument('--config', type=str, default='config.yaml', help='Path to the configuration YAML file (default: config.yaml in current directory)')
args = parser.parse_args()

config_path = args.config if args.config else 'config.yaml'
main(config_path)

0 comments on commit 4ea66c5

Please sign in to comment.