Skip to content

Commit

Permalink
fix: Fix capability name/NAME and add new test (#591)
Browse files Browse the repository at this point in the history
Previous unit tests hid pretty well the fact that NAME wasn't set on capabilities.
  • Loading branch information
GitOnUp authored Oct 24, 2023
1 parent 69a9964 commit 26360ec
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/steamship/plugin/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ class SystemPromptSupport(CapabilityImpl):
The system prompt will come across in other blocks on the request.
"""

name = "steamship.system_prompt_support"
NAME = "steamship.system_prompt_support"

request_level = RequestLevel.BEST_EFFORT
"""
Expand All @@ -324,7 +324,7 @@ class ConversationSupport(CapabilityImpl):
The content of the conversation will come across in other blocks on the request, using the CHAT TagKind.
"""

name = "steamship.conversation_support"
NAME = "steamship.conversation_support"


class FunctionCallingSupport(CapabilityImpl):
Expand All @@ -335,7 +335,7 @@ class FunctionCallingSupport(CapabilityImpl):
following request.
"""

name = "steamship.function_calling_support"
NAME = "steamship.function_calling_support"

functions: List[Tool]
"""A list of Tools which the LLM can choose from to execute."""
Expand Down
22 changes: 22 additions & 0 deletions tests/steamship_tests/plugin/unit/test_capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
CapabilityImpl,
CapabilityPluginRequest,
CapabilityPluginResponse,
ConversationSupport,
FunctionCallingSupport,
RequestedCapabilities,
RequestLevel,
SupportLevel,
SystemPromptSupport,
UnsupportedCapabilityError,
)

Expand Down Expand Up @@ -57,6 +60,25 @@ def test_is_plugin_support_valid(
assert response.fulfilled_at in support_level


def test_builtin_capabilities_support():
original_capabilities = {
ConversationSupport: ConversationSupport(),
SystemPromptSupport: SystemPromptSupport(),
FunctionCallingSupport: FunctionCallingSupport(functions=[]),
}
original = CapabilityPluginRequest(requested_capabilities=list(original_capabilities.values()))
block = original.to_block()
roundtripped = CapabilityPluginRequest.from_block(block)
assert original == roundtripped
requested_capabilities = RequestedCapabilities(
{cap_typ: SupportLevel.NATIVE for cap_typ in original_capabilities.keys()}
)
requested_capabilities.load_requests(roundtripped)
for cap_typ in original_capabilities.keys():
requested = requested_capabilities.get(cap_typ)
assert requested == original_capabilities[cap_typ]


def test_capability_plugin_request_block_roundtrips():
original = CapabilityPluginRequest(
requested_capabilities=[TestCapability(request_level=RequestLevel.NATIVE)]
Expand Down

0 comments on commit 26360ec

Please sign in to comment.