Skip to content

Commit

Permalink
Add support for OpenRouter and MistralAI (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
AAClause committed Feb 27, 2024
2 parents e2ca1a9 + 0d11b02 commit 036c309
Show file tree
Hide file tree
Showing 7 changed files with 604 additions and 192 deletions.
210 changes: 131 additions & 79 deletions addon/globalPlugins/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
import ui
from logHandler import log
from scriptHandler import script, getLastScriptRepeatCount
from . import apikeymanager
from . import configspec
from . import updatecheck
from .apikeymanager import APIKeyManager
from .consts import (
ADDON_DIR, DATA_DIR,
ADDON_DIR, BASE_URLs, DATA_DIR,
LIBS_DIR_PY,
TTS_MODELS, TTS_VOICES
)
Expand All @@ -33,16 +33,85 @@
ROOT_ADDON_DIR
).manifest

NO_AUTHENTICATION_KEY_PROVIDED_MSG = _("No authentication key provided. Please set it in the Preferences dialog.")
NO_AUTHENTICATION_KEY_PROVIDED_MSG = _("No API key provided for any provider, please provide at least one API key in the settings dialog")

conf = config.conf["OpenAI"]
api_key_manager = APIKeyManager(DATA_DIR)


class APIAccessDialog(wx.Dialog):

def __init__(
self,
parent,
title: str,
APIKeyManager: apikeymanager.APIKeyManager,
):
super(APIAccessDialog, self).__init__(parent, title=title)
self.APIKeyManager = APIKeyManager
self.provider_name = APIKeyManager.provider
self.InitUI()
self.CenterOnParent()
self.SetSize((500, 200))

def InitUI(self):
pnl = wx.Panel(self)
vbox = wx.BoxSizer(wx.VERTICAL)
fgs = wx.FlexGridSizer(3, 2, 9, 25) # 3 rows, 2 columns, vertical and horizontal gap

lblAPIKey = wx.StaticText(pnl, label=f"{self.provider_name} API Key:")
self.txtAPIKey = wx.TextCtrl(pnl)

lblOrgName = wx.StaticText(pnl, label="Organization name:")
self.txtOrgName = wx.TextCtrl(pnl)

lblOrgKey = wx.StaticText(pnl, label="Organization key:")
self.txtOrgKey = wx.TextCtrl(pnl)

# Adding Rows to the FlexGridSizer
fgs.AddMany(
[
lblAPIKey, (self.txtAPIKey, 1, wx.EXPAND),
lblOrgName, (self.txtOrgName, 1, wx.EXPAND),
lblOrgKey, (self.txtOrgKey, 1, wx.EXPAND),
])

# Configure an expanding column for text controls
fgs.AddGrowableCol(1, 1)

APIKey = self.APIKeyManager.get_api_key()
if APIKey:
self.txtAPIKey.SetValue(
APIKey
)
orgKey = self.APIKeyManager.get_organization_key()
orgName = self.APIKeyManager.get_organization_name()
if orgKey and orgName:
self.txtOrgName.SetValue(
orgName
)
self.txtOrgKey.SetValue(
orgKey
)

btnsizer = wx.StdDialogButtonSizer()
btnOK = wx.Button(pnl, wx.ID_OK)
btnOK.SetDefault()
btnsizer.AddButton(btnOK)
btnsizer.AddButton(wx.Button(pnl, wx.ID_CANCEL))
btnsizer.Realize()

# Layout sizers
vbox.Add(fgs, proportion=1, flag=wx.ALL|wx.EXPAND, border=10)
vbox.Add(btnsizer, flag=wx.ALIGN_CENTER|wx.TOP|wx.BOTTOM, border=10)
pnl.SetSizer(vbox)


class SettingsDlg(gui.settingsDialogs.SettingsPanel):

title = "Open AI"

def makeSettings(self, settingsSizer):

sHelper = gui.guiHelper.BoxSizerHelper(self, sizer=settingsSizer)

updateGroupLabel = _("Update")
Expand All @@ -69,50 +138,26 @@ def makeSettings(self, settingsSizer):

sHelper.addItem(updateSizer)

APIKey = api_key_manager.get_api_key()
if not APIKey: APIKey = ''
APIKeyOrg = api_key_manager.get_api_key(use_org=True)
org_name = ""
org_key = ""
if APIKeyOrg and ":=" in APIKeyOrg :
org_name, org_key = APIKeyOrg.split(":=")
self.APIKey = sHelper.addLabeledControl(
_("API Key:"),
wx.TextCtrl,
value=APIKey
)

orgGroupLabel = _("Organization")
orgSizer = wx.StaticBoxSizer(wx.VERTICAL, self, label=orgGroupLabel)
orgGroupBox = orgSizer.GetStaticBox()
orgGroup = gui.guiHelper.BoxSizerHelper(self, sizer=orgSizer)

self.use_org = orgGroup.addItem(
wx.CheckBox(
orgGroupBox,
label=_("Use or&ganization"))
)
self.use_org.SetValue(
conf["use_org"]
)
self.use_org.Bind(
wx.EVT_CHECKBOX,
self.onUseOrg
)

self.org_name = orgGroup.addLabeledControl(
_("Organization &name:"),
wx.TextCtrl,
value=org_name
)

self.org_key = orgGroup.addLabeledControl(
_("&Organization key:"),
wx.TextCtrl,
value=org_key
)
APIAccessGroupLabel = _("API Access Keys")
APIAccessSizer = wx.StaticBoxSizer(wx.HORIZONTAL, self, label=APIAccessGroupLabel)
APIAccessBox = APIAccessSizer.GetStaticBox()
APIAccessGroup = gui.guiHelper.BoxSizerHelper(self, sizer=APIAccessSizer)

for provider in apikeymanager.AVAILABLE_PROVIDERS:
item = APIAccessGroup.addItem(
wx.Button(
APIAccessBox,
label=_("%s API &keys...") % provider,
id=wx.ID_ANY,
name=provider
)
)
item.Bind(
wx.EVT_BUTTON,
self.onAPIKeys
)

sHelper.addItem(orgSizer)
sHelper.addItem(APIAccessSizer)

mainDialogGroupLabel = _("Main dialog")
mainDialogSizer = wx.StaticBoxSizer(wx.VERTICAL, self, label=mainDialogGroupLabel)
Expand Down Expand Up @@ -239,12 +284,25 @@ def makeSettings(self, settingsSizer):

sHelper.addItem(mainDialogSizer)

self.onUseOrg(None)
self.onResize(None)

def onUseOrg(self, evt):
self.org_name.Enable(self.use_org.GetValue())
self.org_key.Enable(self.use_org.GetValue())
def onAPIKeys(self, evt):
provider_name = evt.GetEventObject().GetName()
manager = apikeymanager.get(provider_name)
dlg = APIAccessDialog(
self,
"%s API Access Keys" % provider_name,
manager
)
if dlg.ShowModal() == wx.ID_OK:
manager.save_api_key(
dlg.txtAPIKey.GetValue().strip()
)
manager.save_api_key(
dlg.txtOrgKey.GetValue().strip(),
org=True,
org_name=dlg.txtOrgName.GetValue()
)

def onResize(self, evt):
self.maxWidth.Enable(self.resize.GetValue())
Expand All @@ -260,23 +318,6 @@ def onDefaultPrompt(self, evt):
def onSave(self):
conf["update"]["check"] = self.updateCheck.GetValue()
conf["update"]["channel"] = self.updateChannel.GetString(self.updateChannel.GetSelection())
api_key = self.APIKey.GetValue().strip()
api_key_manager.save_api_key(api_key)
api_key_org = self.org_key.GetValue().strip()
conf["use_org"] = self.use_org.GetValue()
org_name = self.org_name.GetValue().strip()
if conf["use_org"]:
if not api_key_org:
self.org_key.SetFocus()
return
if not org_name:
self.org_name.SetFocus()
return
api_key_manager.save_api_key(
api_key_org,
org=True,
org_name=org_name
)
conf["blockEscapeKey"] = self.blockEscape.GetValue()
conf["renewClient"] = True
conf["saveSystem"] = self.saveSystem.GetValue()
Expand All @@ -295,17 +336,24 @@ def onSave(self):
else:
conf["images"]["useCustomPrompt"] = False


class GlobalPlugin(globalPluginHandler.GlobalPlugin):

scriptCategory = "Open AI"

def __init__(self):
super().__init__()
APIKey = api_key_manager.get_api_key()
gui.settingsDialogs.NVDASettingsDialog.categoryClasses.append(SettingsDlg)
self.client = None
self.recordtThread = None
self.createMenu()
apikeymanager.load(DATA_DIR)
log.info(
"Open AI initialized. Version: %s. %d providers" % (
ADDON_INFO["version"],
len(apikeymanager._managers or [])
)
)

def createMenu(self):
self.submenu = wx.Menu()
Expand Down Expand Up @@ -409,20 +457,24 @@ def getClient(self):
conf["renewClient"] = False
if self.client:
return self.client
api_key = api_key_manager.get_api_key()
organization = api_key_manager.get_api_key(use_org=True)
if not api_key or not api_key.strip():
return None
if conf["use_org"]:
if not organization or not organization.strip():

# initialize the client with the first available provider, will be adjusted on the fly if needed
for provider in apikeymanager.AVAILABLE_PROVIDERS:
manager = apikeymanager.get(provider)
if not manager.ready:
continue
api_key = manager.get_api_key()
if not api_key or not api_key.strip():
return None
self.client = OpenAI(
organization=organization.split(":=")[1],
api_key=api_key
)
else:
self.client = OpenAI(api_key=api_key)
return self.client
organization = manager.get_api_key(use_org=True)
if organization and organization.count(":=") == 1:
self.client.organization = organization.split(":=")[1]
self.client.base_url = BASE_URLs[manager.provider]
return self.client
return None

def checkScreenCurtain(self):
from visionEnhancementProviders.screenCurtain import ScreenCurtainProvider
Expand Down
71 changes: 64 additions & 7 deletions addon/globalPlugins/openai/apikeymanager.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,36 @@
import os

API_KEY_FILENAME = "OpenAI.key"
API_KEY_ORG_FILENAME = "OpenAI_org.key"
AVAILABLE_PROVIDERS = [
"OpenAI",
"OpenRouter",
"MistralAI"
]

_managers = {}

class APIKeyManager:

"""
Manage API key
"""

def __init__(self, data_dir):
def __init__(
self,
data_dir,
provider="OpenAI"
):
if provider not in AVAILABLE_PROVIDERS:
raise ValueError(f"Unknown provider: {provider}")
self.data_dir = data_dir
self.api_key_path = os.path.join(data_dir, API_KEY_FILENAME)
self.api_key_org_path = os.path.join(data_dir, API_KEY_ORG_FILENAME)
self.provider = provider
self.api_key_path = os.path.join(
data_dir,
f"{provider}.key"
)
self.api_key_org_path = os.path.join(
data_dir,
f"{provider}_org.key"
)
self.api_key = None
self.api_key_org = None
self.ensure_data_dir()
Expand All @@ -26,7 +44,7 @@ def _read_api_key_from_file(self, file_path):
with open(file_path, "r") as f:
return f.read().strip()
except FileNotFoundError:
return None
return ""

def get_api_key(self, use_org=False):
if use_org:
Expand All @@ -36,7 +54,21 @@ def get_api_key(self, use_org=False):

if self.api_key is None:
self.api_key = self._read_api_key_from_file(self.api_key_path)
return self.api_key or os.getenv("OPENAI_API_KEY")
return self.api_key or (
os.getenv("OPENAI_API_KEY" if self.provider == "OpenAI" else "OPENROUTER_API_KEY")
)

def get_organization_key(self):
organization = self.get_api_key(use_org=True)
if not organization or organization.count(":=") != 1:
return None
return organization.split(":=")[1]

def get_organization_name(self):
organization = self.get_api_key(use_org=True)
if not organization or organization.count(":") != 1:
return None
return organization.split(":")[0]

def save_api_key(self, key, org=False, org_name=None):
file_path = self.api_key_org_path if org else self.api_key_path
Expand All @@ -49,3 +81,28 @@ def save_api_key(self, key, org=False, org_name=None):
self.api_key_org = f"{org_name}:={key}"
else:
self.api_key = key

def ready(self):
return self.get_api_key() is not None


def load(
data_dir: str
):
"""
Initialize API key manager for all providers
"""
global _managers
for provider in AVAILABLE_PROVIDERS:
_managers[provider] = APIKeyManager(data_dir, provider)


def get(
provider_name: str
) -> APIKeyManager:
"""
Get API key manager for provider_name
"""
if provider_name not in AVAILABLE_PROVIDERS:
raise ValueError(f"Unknown provider: {provider_name}. Available: {AVAILABLE_PROVIDERS}")
return _managers[provider_name]
1 change: 0 additions & 1 deletion addon/globalPlugins/openai/configspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"check": "boolean(default=True)",
"channel": "string(default='stable')"
},
"use_org": "boolean(default=False)",
"model": f"string(default={DEFAULT_MODEL.name})",
"topP": f"integer(min={TOP_P_MIN}, max={TOP_P_MAX}, default={DEFAULT_TOP_P})",
"n": f"integer(min={N_MIN}, max={N_MAX}, default={DEFAULT_N})",
Expand Down
Loading

0 comments on commit 036c309

Please sign in to comment.