Skip to content

Commit

Permalink
Add isTrainable field to plugin; align with new simdex code (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkolas authored Apr 18, 2022
1 parent 639a8e0 commit b8fb8ea
Show file tree
Hide file tree
Showing 13 changed files with 211 additions and 181 deletions.
31 changes: 6 additions & 25 deletions src/steamship/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
from steamship.client.operations.tagger import TagRequest
from steamship.client.tasks import Tasks
from steamship.data import File
from steamship.data.embeddings import EmbedAndSearchRequest, EmbedAndSearchResponse, EmbeddingIndex
from steamship.data.embeddings import EmbedAndSearchRequest, QueryResults, EmbeddingIndex
from steamship.data.search import Hit
from steamship.data.space import Space

__copyright__ = "Steamship"
__license__ = "MIT"

from steamship.extension.file import TagResponse
from steamship.plugin.outputs.block_and_tag_plugin_output import BlockAndTagPluginOutput
from steamship.plugin.outputs.embedded_items_plugin_output import EmbeddedItemsPluginOutput

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -128,27 +130,6 @@ def scrape(
space=space
)

def embed(
self,
docs: List[str],
pluginInstance: str,
spaceId: str = None,
spaceHandle: str = None,
space: Space = None
) -> Response[EmbeddedItemsPluginOutput]:
req = EmbedRequest(
docs=docs,
pluginInstance=pluginInstance
)
return self.post(
'embedding/create',
req,
expect=EmbeddedItemsPluginOutput,
spaceId=spaceId,
spaceHandle=spaceHandle,
space=space
)

def embed_and_search(
self,
query: str,
Expand All @@ -158,17 +139,17 @@ def embed_and_search(
spaceId: str = None,
spaceHandle: str = None,
space: Space = None
) -> Response[EmbedAndSearchResponse]:
) -> Response[QueryResults]:
req = EmbedAndSearchRequest(
query=query,
docs=docs,
pluginInstance=pluginInstance,
k=k
)
return self.post(
'embedding/search',
'plugin/instance/embeddingSearch',
req,
expect=EmbedAndSearchResponse,
expect=QueryResults,
spaceId=spaceId,
spaceHandle=spaceHandle,
space=space
Expand Down
52 changes: 31 additions & 21 deletions src/steamship/data/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from dataclasses import dataclass
from typing import List, Dict, Union
from typing import List, Dict, Union, TypeVar, Generic

from steamship.base import Client, Request, Response, metadata_to_str
from steamship.data.search import Hit
Expand All @@ -13,18 +13,40 @@ class EmbedAndSearchRequest(Request):
pluginInstance: str
k: int = 1


#TODO: These types are not generics like the Swift QueryResult/QueryResults.
@dataclass
class QueryResult():
value: Hit
score: float
index: int
id: str

@staticmethod
def from_dict(d: any, client: Client = None) -> "QueryResult":
value = Hit.from_dict(d.get("value", {}))
return QueryResult(
value = value,
score = d.get('score'),
index = d.get('index'),
id = d.get('id')
)

@dataclass
class EmbedAndSearchResponse(Request):
hits: List[Hit] = None
class QueryResults(Request):
items: List[QueryResult] = None

@staticmethod
def from_dict(d: any, client: Client = None) -> "EmbedAndSearchResponse":
hits = [Hit.from_dict(h) for h in (d.get("hits", []) or [])]
return EmbedAndSearchResponse(
hits=hits
def from_dict(d: any, client: Client = None) -> "QueryResults":
items = [QueryResult.from_dict(h) for h in (d.get("items", []) or [])]
return QueryResults(
items=items
)





@dataclass
class EmbeddedItem:
id: str = None
Expand Down Expand Up @@ -145,18 +167,6 @@ class IndexSearchRequest(Request):
includeMetadata: bool = False


@dataclass
class IndexSearchResponse:
hits: List[Hit] = None

@staticmethod
def from_dict(d: any, client: Client = None) -> "IndexSearchResponse":
hits = [Hit.from_dict(h) for h in (d.get("hits", []) or [])]
return IndexSearchResponse(
hits=hits
)


@dataclass
class IndexSnapshotRequest(Request):
indexId: str
Expand Down Expand Up @@ -467,7 +477,7 @@ def search(
spaceId: str = None,
spaceHandle: str = None,
space: any = None
) -> Response[IndexSearchResponse]:
) -> Response[QueryResults]:
if type(query) == list:
req = IndexSearchRequest(
self.id,
Expand All @@ -487,7 +497,7 @@ def search(
ret = self.client.post(
'embedding-index/search',
req,
expect=IndexSearchResponse,
expect=QueryResults,
spaceId=spaceId,
spaceHandle=spaceHandle,
space=space
Expand Down
5 changes: 4 additions & 1 deletion src/steamship/data/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class Plugin:

@dataclass
class CreatePluginRequest(Request):
isTrainable: bool
id: str = None
name: str = None
type: str = None
Expand Down Expand Up @@ -69,9 +70,9 @@ class GetPluginRequest(Request):


class PluginType:
embedder = "embedder"
parser = "parser"
classifier = "classifier"
tagger = "tagger"


class PluginAdapterType:
Expand Down Expand Up @@ -133,6 +134,7 @@ def from_dict(d: any, client: Client = None) -> "Plugin":
@staticmethod
def create(
client: Client,
isTrainable: bool,
name: str,
description: str,
type: str,
Expand All @@ -152,6 +154,7 @@ def create(
metadata = json.dumps(metadata)

req = CreatePluginRequest(
isTrainable=isTrainable,
name=name,
type=type,
transport=transport,
Expand Down
1 change: 1 addition & 0 deletions src/steamship/data/tags/text_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ class TextTag:
isOov = "isOov"
isStop = "isStop"
lang = "lang"
embedding = "embedding"
4 changes: 2 additions & 2 deletions tests/client/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,10 @@ def deploy_app(py_name: str, versionConfigTemplate : Dict[str, any] = None, inst


@contextlib.contextmanager
def deploy_plugin(py_name: str, plugin_type: str, versionConfigTemplate : Dict[str, any] = None, instanceConfig : Dict[str, any] = None):
def deploy_plugin(py_name: str, plugin_type: str, versionConfigTemplate : Dict[str, any] = None, instanceConfig : Dict[str, any] = None, isTrainable: bool = False):
client = _steamship()
name = _random_name()
plugin = Plugin.create(client, name=name, description='test', type=plugin_type, transport="jsonOverHttp",
plugin = Plugin.create(client, isTrainable=isTrainable, name=name, description='test', type=plugin_type, transport="jsonOverHttp",
isPublic=False)
assert (plugin.error is None)
assert (plugin.data is not None)
Expand Down
36 changes: 24 additions & 12 deletions tests/client/operations/test_embed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from steamship import PluginInstance
from steamship import PluginInstance, File
from steamship.base import Client

from tests.client.helpers import _steamship
Expand All @@ -8,19 +8,31 @@

_TEST_EMBEDDER = "test-embedder"

def count_embeddings(file: File):
embeddings = 0
for block in file.blocks:
for tag in block.tags:
if tag.kind == 'text' and tag.name == 'embedding':
embeddings += 1
return embeddings

def basic_embeddings(steamship: Client, pluginInstance: str):
e1 = steamship.embed(["This is a test"], pluginInstance=pluginInstance)
e1b = steamship.embed(["Banana"], pluginInstance=pluginInstance)
assert (len(e1.data.embeddings) == 1)
assert (len(e1.data.embeddings[0]) > 1)
e1 = steamship.tag("This is a test", pluginInstance=pluginInstance)
e1b = steamship.tag("Banana", pluginInstance=pluginInstance)
e1.wait()
e1b.wait()
assert (count_embeddings(e1.data.file) == 1)
assert (count_embeddings(e1b.data.file) == 1)
assert (len(e1.data.file.blocks[0].tags[0].value['embedding']) > 1)

e2 = steamship.embed(["This is a test"], pluginInstance=pluginInstance)
assert (len(e2.data.embeddings) == 1)
assert (len(e2.data.embeddings[0]) == len(e1.data.embeddings[0]))
e2 = steamship.tag("This is a test", pluginInstance=pluginInstance)
e2.wait()
assert (count_embeddings(e2.data.file) == 1)
assert (len(e2.data.file.blocks[0].tags[0].value['embedding']) == len(e1.data.file.blocks[0].tags[0].value['embedding']))

e4 = steamship.embed(["This is a test"], pluginInstance=pluginInstance)
assert (len(e4.data.embeddings) == 1)
e4 = steamship.tag("This is a test", pluginInstance=pluginInstance)
e4.wait()
assert (count_embeddings(e4.data.file) == 1)


def test_basic_embeddings():
Expand All @@ -38,8 +50,8 @@ def basic_embedding_search(steamship: Client, pluginInstance: str):
]
query = "Who should I talk to about new employee setup?"
results = steamship.embed_and_search(query, docs, pluginInstance=pluginInstance)
assert (len(results.data.hits) == 1)
assert (results.data.hits[0].value == "Jonathan can help you with new employee onboarding")
assert (len(results.data.items) == 1)
assert (results.data.items[0].value.value == "Jonathan can help you with new employee onboarding")


def test_basic_embedding_search():
Expand Down
20 changes: 10 additions & 10 deletions tests/client/operations/test_embed_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def test_file_parse():
embedResp.wait()

res = index.search("What color are roses?").data
assert (len(res.hits) == 1)
assert (len(res.items) == 1)
# Because the simdex now indexes entire blocks and not sentences, the result of this is the whole block text
assert (res.hits[0].value == " ".join([P1_1, P1_2]))
assert (res.items[0].value.value == " ".join([P1_1, P1_2]))

a.delete()

Expand Down Expand Up @@ -115,14 +115,14 @@ def test_file_index():
index = a.index(pluginInstance=embedder.handle)

res = index.search("What color are roses?").data
assert (len(res.hits) == 1)
assert (len(res.items) == 1)
# Because the simdex now indexes entire blocks and not sentences, the result of this is the whole block text
assert (res.hits[0].value == " ".join([P1_1, P1_2]))
assert (res.items[0].value.value == " ".join([P1_1, P1_2]))

res = index.search("What flavors does cake come in?").data
assert (len(res.hits) == 1)
assert (len(res.items) == 1)
# Because the simdex now indexes entire blocks and not sentences, the result of this is the whole block text
assert (res.hits[0].value == " ".join([P4_1, P4_2]))
assert (res.items[0].value.value == " ".join([P4_1, P4_2]))

index.delete()
a.delete()
Expand Down Expand Up @@ -172,12 +172,12 @@ def test_file_embed_lookup():
index.insert_file(b.id, blockType='sentence', reindex=True)

res = index.search("What does Ted like to do?").data
assert (len(res.hits) == 1)
assert (res.hits[0].value == content_a)
assert (len(res.items) == 1)
assert (res.items[0].value.value == content_a)

res = index.search("What does Grace like to do?").data
assert (len(res.hits) == 1)
assert (res.hits[0].value == content_b)
assert (len(res.items) == 1)
assert (res.items[0].value.value == content_b)

# Now we list the items
itemsa = index.list_items(fileId=a.id).data
Expand Down
Loading

0 comments on commit b8fb8ea

Please sign in to comment.