diff --git a/litellm/utils.py b/litellm/utils.py index b1a0b13e7760..a82b4d0960c0 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -35,6 +35,7 @@ from dataclasses import dataclass, field from functools import lru_cache, wraps from importlib import resources +from importlib.metadata import entry_points from inspect import iscoroutine from io import StringIO from os.path import abspath, dirname, join @@ -379,10 +380,22 @@ def print_verbose( ####### CLIENT ################### # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking +def load_custom_provider_entrypoints(): + found_entry_points = tuple(entry_points().select(group="litellm")) # type: ignore + for entry_point in found_entry_points: + # types are ignored because of circular dependency issues importing CustomLLM and CustomLLMItem + HandlerClass = entry_point.load() + handler = HandlerClass() + provider = {"provider": entry_point.name, "custom_handler": handler} + litellm.custom_provider_map.append(provider) # type: ignore + + def custom_llm_setup(): """ Add custom_llm provider to provider list """ + load_custom_provider_entrypoints() + for custom_llm in litellm.custom_provider_map: if custom_llm["provider"] not in litellm.provider_list: litellm.provider_list.append(custom_llm["provider"]) diff --git a/tests/local_testing/test_custom_llm.py b/tests/local_testing/test_custom_llm.py index 61312f4e52f2..16d20984d5b8 100644 --- a/tests/local_testing/test_custom_llm.py +++ b/tests/local_testing/test_custom_llm.py @@ -10,6 +10,7 @@ import openai import pytest +from pytest_mock import MockerFixture sys.path.insert( 0, os.path.abspath("../..") @@ -536,4 +537,47 @@ async def test_simple_aembedding(): "object": "embedding", "embedding": [0.1, 0.2, 0.3], "index": 1, - } \ No newline at end of file + } + + +def test_custom_llm_provider_entrypoint(mocker: MockerFixture): + # This test mocks the use of entry-points in pyproject.toml: + # [project.entry-point.litellm] + # custom_llm = :MyCustomLLM + # another-custom-llm = :AnotherCustomLLM + + from litellm.utils import custom_llm_setup + + class AnotherCustomLLM(CustomLLM): + pass + + providers = { + "custom_llm": MyCustomLLM, + "another-custom-llm": AnotherCustomLLM + } + + def load(self): + return providers[self.name] + + mocker.patch("importlib.metadata.EntryPoint.load", load) + from importlib.metadata import EntryPoints, EntryPoint + + entry_points = EntryPoints([ + EntryPoint(group="litellm", name="custom_llm", value="package.module:MyCustomLLM"), + EntryPoint(group="litellm", name="another-custom-llm", value="package.module:AnotherCustomLLM"), + ]) + mocked = mocker.patch("litellm.utils.entry_points") + mocked.return_value = entry_points + + assert litellm.custom_provider_map == [] + assert litellm._custom_providers == [] + + custom_llm_setup() + + assert litellm._custom_providers == ['custom_llm', 'another-custom-llm'] + + assert litellm.custom_provider_map[0]["provider"] == "custom_llm" + assert isinstance(litellm.custom_provider_map[0]["custom_handler"], CustomLLM) + + assert litellm.custom_provider_map[1]["provider"] == "another-custom-llm" + assert isinstance(litellm.custom_provider_map[1]["custom_handler"], AnotherCustomLLM)