Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing personalization with Gorilla #56

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ dist
.vscode
.idea
.editorconfig
.env
61 changes: 60 additions & 1 deletion go_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,32 @@
import fcntl
import platform
import requests
from openai import OpenAI
import json
import subprocess
import argparse
import termios
import urllib.parse
import sys
from halo import Halo
import go_questionary
from personalization.personalize import GorillaPersonalizer
from personalization.setup import PersonalizationSetup

__version__ = "0.0.11" # current version

__version__ = "0.0.12" # current version
SERVER_URL = "https://cli.gorilla-llm.com"
UPDATE_CHECK_FILE = os.path.expanduser("~/.gorilla-cli-last-update-check")
USERID_FILE = os.path.expanduser("~/.gorilla-cli-userid")
HISTORY_FILE = os.path.expanduser("~/.gorilla_cli_history")
CONFIG_FILE = os.path.expanduser("~/.gorilla-cli-config.json")
ISSUE_URL = f"https://github.com/gorilla-llm/gorilla-cli/issues/new"
GORILLA_EMOJI = "🦍 " if go_questionary.try_encode_gorilla() else ""
HISTORY_LENGTH = 10
WELCOME_TEXT = f"""===***===
{GORILLA_EMOJI}Welcome to Gorilla-CLI! Enhance your Command Line with the power of LLMs!


Simply use `gorilla <your desired operation>` and Gorilla will do the rest. For instance:
gorilla generate 100 random characters into a file called test.txt
gorilla get the image ids of all pods running in all namespaces in kubernetes
Expand All @@ -54,6 +61,13 @@
def generate_random_uid():
return str(uuid.uuid4())

def load_config():
arthbohra marked this conversation as resolved.
Show resolved Hide resolved
# Load the user's configuration file and perform any necessary checks
if os.path.isfile(CONFIG_FILE):
with open(CONFIG_FILE, "r") as config_file:
config_json = json.load(config_file)
return config_json

def get_git_email():
return subprocess.check_output(["git", "config", "--global", "user.email"]).decode("utf-8").strip()

Expand Down Expand Up @@ -122,8 +136,32 @@ def check_for_updates():
except Exception as e:
print("Unable to write update check file:", e)

def setup_config_file(user_id):
config_json = { "user_id": user_id }
with open(CONFIG_FILE, "w") as config_file:
json.dump(config_json, config_file)

def get_user_id():
if os.path.isfile(CONFIG_FILE):
with open(CONFIG_FILE, "r") as config_file:
try:
config_json = json.load(config_file)
if "user_id" not in config_json:
user_id = get_user_id_deprecated()
setup_config_file(user_id)
return get_user_id_deprecated()
else:
return config_json["user_id"]
except:
user_id = get_user_id_deprecated()
setup_config_file(user_id)
return get_user_id_deprecated
else:
user_id = get_user_id_deprecated()
setup_config_file(user_id)
return get_user_id_deprecated()

def get_user_id_deprecated():
# Unique user identifier for authentication and load balancing
# Gorilla-CLI is hosted by UC Berkeley Sky lab for FREE as a
# research prototype. Please don't spam the system or use it
Expand Down Expand Up @@ -184,6 +222,9 @@ def append_string_to_file_if_missing(file_path, target_string):
file.write(target_string)





def main():
def execute_command(cmd):
cmd = format_command(cmd)
Expand Down Expand Up @@ -220,6 +261,13 @@ def get_history_commands(history_file):
user_id = get_user_id()
system_info = get_system_info()

personalization_setup = PersonalizationSetup()
# personalization_setup.request_personalization()
personalization = personalization_setup.personalization
open_ai_key = personalization_setup.open_ai_key

personalized_history = None


# Parse command-line arguments
parser = argparse.ArgumentParser(description="Gorilla CLI Help Doc")
Expand All @@ -231,13 +279,24 @@ def get_history_commands(history_file):
# Generate a unique interaction ID
interaction_id = str(uuid.uuid4())



commands = []
if args.history:
commands = get_history_commands(HISTORY_FILE)

if personalization:
personalizer = GorillaPersonalizer(open_ai_key)
personalized_history = personalizer.personalize(user_input, commands)
print (personalized_history)


else:
with Halo(text=f"{GORILLA_EMOJI}Loading", spinner="dots"):
try:
data_json = {
"user_id": user_id,
#"synthesized_history": personalized_history,
"user_input": user_input,
"interaction_id": interaction_id,
"system_info": system_info
Expand Down
106 changes: 106 additions & 0 deletions personalization/personalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import os
from presidio_analyzer import AnalyzerEngine, PatternRecognizer
from presidio_anonymizer import AnonymizerEngine
from presidio_anonymizer.entities import OperatorConfig
import json
from pprint import pprint
from openai import OpenAI
from personalization.prompts import get_system_prompt, get_user_prompt

class GorillaPersonalizer:
"""A class to personalize the user's bash history and query to provide the model with relevant context.

Attributes:
client: The OpenAI client.

"""

def __init__(self, open_ai_key):
""" Initializes the GorillaPersonalizer class."""
self.client = OpenAI(api_key=open_ai_key)

def get_bash_history(self):
"""
Retrieves the user's bash history. (Last 10 commands)
"""
history_file = os.path.expanduser("~/.bash_history")
prev_operations = ""
try:
with open(history_file, "r") as file:
history = file.readlines()
except FileNotFoundError:
return "No bash history was found."
return history[:-10]

def anonymize_bash_history(self, operations):
"""
Uses Microsoft's Presidio to anonymize the user's bash history.

Args:
operations: The user's bash history.

"""
analyzer = AnalyzerEngine()
analyzer_results = analyzer.analyze(text=operations, language="en")
anonymizer = AnonymizerEngine()
anonymized_results = anonymizer.anonymize(
text=operations, analyzer_results=analyzer_results
)
return anonymized_results.text

def remove_duplicates(self, operations: list[str]):
"""Removes duplicates from the user's bash history

Args:
operations: The user's bash history.

"""
return list(set(operations))

def stringify_bash_history(self, operations: list[str]):
"""Stringifies the user's bash history.

Args:
operations: The user's bash history.

"""
return "\n".join(operations)

def synthesize_bash_history(self, desired_operation, gorilla_history, history):
"""Uses OpenAI api to synthesize the user's bash, gorilla history.
It synthesizes this history with the current operation in mind, with the goal of providing the model with relevant context.

Args:
client: The OpenAI client.
desired_operation: The operation the user wants to perform.
gorilla_history: The user's previous operations with the Gorilla
history: The user's bash history.

"""
system_prompt = get_system_prompt()
user_prompt = get_user_prompt(history, gorilla_history, desired_operation)
response = self.client.chat.completions.create(
model="gpt-4",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
)
return response.choices[0].message.content

def personalize(self, query, gorilla_history, pi_removal=True):
"""Personalizes the user's bash history and query to provide the model with relevant context.

Args:
query: The operation the user wants to perform.
gorila_history: The user's previous operations with the Gorilla
open_ai_key: The OpenAI API key.
pi_removal: Whether to remove personally identifiable information from the user's bash history.

"""

history = self.stringify_bash_history(self.remove_duplicates(self.get_bash_history()))
if pi_removal:
history = self.anonymize_bash_history(history)
summary = self.synthesize_bash_history(query, gorilla_history, history)
return summary
26 changes: 26 additions & 0 deletions personalization/prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@


def get_system_prompt():
return f"""
You are an assistant for a developer who wants to find the right API call for a specific task.
The developer has bash history that contains the command they used to perform a task.
Synthesize their bash history to provide the API call prediction model with extra context about the task.
For reference, the API call prediction model, called Gorilla, is trained on a large dataset of API calls and their associated tasks.
You may see the developer's previous operations with the API calling tool in their bash history.
Use the previous bash history as well as their query to provide the model with a short paragraph of possible relevant context.
There is a chance that their query has nothing to do with the bash history, so in that case, return 'No relevant context found'.
"""

def get_user_prompt(history, gorilla_history, desired_operation):
return f"""
The user's bash history is:
{history}

The user's previous operations with the API calling tool are:
{gorilla_history}

The query of the user is:
{desired_operation}

Us this information to provide the model with a short paragraph of possible relevant context.
"""
Loading