diff --git a/src/resin/chat_engine/chat_engine.py b/src/resin/chat_engine/chat_engine.py index fa1aee48..f0282322 100644 --- a/src/resin/chat_engine/chat_engine.py +++ b/src/resin/chat_engine/chat_engine.py @@ -20,7 +20,8 @@ DEFAULT_SYSTEM_PROMPT = """Use the following pieces of context to answer the user question at the next messages. This context retrieved from a knowledge database and you should use only the facts from the context to answer. Always remember to include the source to the documents you used from their 'source' field in the format 'Source: $SOURCE_HERE'. If you don't know the answer, just say that you don't know, don't try to make up an answer, use the context. -Don't address the context directly, but use it to answer the user question like it's your own knowledge.""" # noqa +Don't address the context directly, but use it to answer the user question like it's your own knowledge. +""" # noqa class BaseChatEngine(ABC, ConfigurableMixin): diff --git a/src/resin/chat_engine/query_generator/function_calling.py b/src/resin/chat_engine/query_generator/function_calling.py index 19c8d93c..b323774c 100644 --- a/src/resin/chat_engine/query_generator/function_calling.py +++ b/src/resin/chat_engine/query_generator/function_calling.py @@ -9,7 +9,8 @@ from resin.models.data_models import Messages, Query DEFAULT_SYSTEM_PROMPT = """Your task is to formulate search queries for a search engine, to assist in responding to the user's question. -You should break down complex questions into sub-queries if needed.""" # noqa: E501 +You should break down complex questions into sub-queries if needed. +""" # noqa: E501 DEFAULT_FUNCTION_DESCRIPTION = """Query search engine for relevant information""" diff --git a/src/resin/knowledge_base/knowledge_base.py b/src/resin/knowledge_base/knowledge_base.py index 9cf14a01..13a2f76c 100644 --- a/src/resin/knowledge_base/knowledge_base.py +++ b/src/resin/knowledge_base/knowledge_base.py @@ -152,7 +152,7 @@ def __init__(self, # Normally, index creation params are passed directly to the `.create_resin_index()` method. # noqa: E501 # However, when KnowledgeBase is initialized from a config file, these params # would be set by the `KnowledgeBase.from_config()` constructor. - self._index_params: Optional[Dict[str, Any]] = None + self._index_params: Dict[str, Any] = {} # The index object is initialized lazily, when the user calls `connect()` or # `create_resin_index()` @@ -306,7 +306,7 @@ def create_resin_index(self, ) # create index - index_params = index_params or self._index_params or {} + index_params = index_params or self._index_params try: create_index(name=self.index_name, dimension=dimension, diff --git a/src/resin/llm/openai.py b/src/resin/llm/openai.py index af7ca672..9e43fb6e 100644 --- a/src/resin/llm/openai.py +++ b/src/resin/llm/openai.py @@ -24,12 +24,10 @@ def __init__(self, ): super().__init__(model_name, model_params=model_params) - self.available_models = [k["id"] for k in openai.Model.list().data] - if model_name not in self.available_models: - raise ValueError( - f"Model {model_name} not found. " - f" Available models: {self.available_models}" - ) + + @property + def available_models(self): + return [k["id"] for k in openai.Model.list().data] @retry( wait=wait_random_exponential(min=1, max=10), diff --git a/tests/system/llm/test_openai.py b/tests/system/llm/test_openai.py index 1caad24c..ef5ad677 100644 --- a/tests/system/llm/test_openai.py +++ b/tests/system/llm/test_openai.py @@ -153,11 +153,6 @@ def test_max_tokens(openai_llm, messages): assert isinstance(response, ChatResponse) assert len(response.choices[0].message.content.split()) <= max_tokens - @staticmethod - def test_invalid_model_name(): - with pytest.raises(ValueError, match="Model invalid_model_name not found."): - OpenAILLM(model_name="invalid_model_name") - @staticmethod def test_missing_messages(openai_llm): with pytest.raises(InvalidRequestError): diff --git a/tests/system/utils/__init__.py b/tests/system/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/system/utils/test_config.py b/tests/system/utils/test_config.py new file mode 100644 index 00000000..84dbf1fe --- /dev/null +++ b/tests/system/utils/test_config.py @@ -0,0 +1,53 @@ +import os + +import pytest +import yaml + +from resin.chat_engine import ChatEngine +from resin.context_engine import ContextEngine +from resin.knowledge_base import KnowledgeBase + +DEFAULT_COFIG_PATH = 'config/config.yaml' + + +@pytest.fixture(scope='module') +def temp_index_name(): + index_name_before = os.getenv("INDEX_NAME", None) + + os.environ["INDEX_NAME"] = "temp_index" + yield "temp_index" + + if index_name_before is None: + del os.environ["INDEX_NAME"] + else: + os.environ["INDEX_NAME"] = index_name_before + + +def test_default_config_matches_code_defaults(temp_index_name): + with open(DEFAULT_COFIG_PATH) as f: + default_config = yaml.safe_load(f) + chat_engine_config = default_config['chat_engine'] + + loaded_chat_engine = ChatEngine.from_config(chat_engine_config) + default_kb = KnowledgeBase(index_name=temp_index_name) + default_context_engine = ContextEngine(default_kb) + default_chat_engine = ChatEngine(default_context_engine) + + def assert_identical_components(loaded_component, default_component): + assert type(loaded_component) == type(default_component) # noqa: E721 + if not loaded_component.__module__.startswith("resin"): + return + + for key, value in default_component.__dict__.items(): + assert hasattr(loaded_component, key), ( + f"Missing attribute {key} in {type(loaded_component)}" + ) + if hasattr(value, '__dict__'): + assert_identical_components(getattr(loaded_component, key), value) + else: + assert getattr(loaded_component, key) == value, ( + f"Attribute {key} in {type(loaded_component)} is {value} in code " + f"but {getattr(loaded_component, key)} in config" + ) + + assert_identical_components(loaded_chat_engine, default_chat_engine) diff --git a/tests/unit/chat_engine/test_chat_engine.py b/tests/unit/chat_engine/test_chat_engine.py index e14354c0..1a0e6dd5 100644 --- a/tests/unit/chat_engine/test_chat_engine.py +++ b/tests/unit/chat_engine/test_chat_engine.py @@ -19,7 +19,7 @@ class TestChatEngine: - def setup(self): + def setup_method(self): self.mock_llm = create_autospec(BaseLLM) self.mock_query_builder = create_autospec(QueryGenerator) self.mock_context_engine = create_autospec(ContextEngine) diff --git a/tests/unit/utils/test_config.py b/tests/unit/utils/test_config.py index d7ebdde8..1d8f5f53 100644 --- a/tests/unit/utils/test_config.py +++ b/tests/unit/utils/test_config.py @@ -1,4 +1,5 @@ # noqa: F405 + import pytest from ._stub_classes import (BaseStubChunker, StubChunker, StubKB,