Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions litellm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from dataclasses import dataclass, field
from functools import lru_cache, wraps
from importlib import resources
from importlib.metadata import entry_points
from inspect import iscoroutine
from io import StringIO
from os.path import abspath, dirname, join
Expand Down Expand Up @@ -379,10 +380,22 @@ def print_verbose(

####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def load_custom_provider_entrypoints():
found_entry_points = tuple(entry_points().select(group="litellm")) # type: ignore
for entry_point in found_entry_points:
# types are ignored because of circular dependency issues importing CustomLLM and CustomLLMItem
HandlerClass = entry_point.load()
handler = HandlerClass()
provider = {"provider": entry_point.name, "custom_handler": handler}
litellm.custom_provider_map.append(provider) # type: ignore


def custom_llm_setup():
"""
Add custom_llm provider to provider list
"""
load_custom_provider_entrypoints()

for custom_llm in litellm.custom_provider_map:
if custom_llm["provider"] not in litellm.provider_list:
litellm.provider_list.append(custom_llm["provider"])
Expand Down
46 changes: 45 additions & 1 deletion tests/local_testing/test_custom_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import openai
import pytest
from pytest_mock import MockerFixture

sys.path.insert(
0, os.path.abspath("../..")
Expand Down Expand Up @@ -536,4 +537,47 @@ async def test_simple_aembedding():
"object": "embedding",
"embedding": [0.1, 0.2, 0.3],
"index": 1,
}
}


def test_custom_llm_provider_entrypoint(mocker: MockerFixture):
# This test mocks the use of entry-points in pyproject.toml:
# [project.entry-point.litellm]
# custom_llm = <module>:MyCustomLLM
# another-custom-llm = <module>:AnotherCustomLLM

from litellm.utils import custom_llm_setup

class AnotherCustomLLM(CustomLLM):
pass

providers = {
"custom_llm": MyCustomLLM,
"another-custom-llm": AnotherCustomLLM
}

def load(self):
return providers[self.name]

mocker.patch("importlib.metadata.EntryPoint.load", load)
from importlib.metadata import EntryPoints, EntryPoint

entry_points = EntryPoints([
EntryPoint(group="litellm", name="custom_llm", value="package.module:MyCustomLLM"),
EntryPoint(group="litellm", name="another-custom-llm", value="package.module:AnotherCustomLLM"),
])
mocked = mocker.patch("litellm.utils.entry_points")
mocked.return_value = entry_points

assert litellm.custom_provider_map == []
assert litellm._custom_providers == []

custom_llm_setup()

assert litellm._custom_providers == ['custom_llm', 'another-custom-llm']

assert litellm.custom_provider_map[0]["provider"] == "custom_llm"
assert isinstance(litellm.custom_provider_map[0]["custom_handler"], CustomLLM)

assert litellm.custom_provider_map[1]["provider"] == "another-custom-llm"
assert isinstance(litellm.custom_provider_map[1]["custom_handler"], AnotherCustomLLM)
Loading