Skip to content

Commit

Permalink
Generator Request Tags (#582)
Browse files Browse the repository at this point in the history
This PR provides the structure for passing Tags into Generation
requests. The Tags that are passed will be applied by the Engine to all
Blocks which are created as part of the generate request.
  • Loading branch information
dkolas authored Oct 13, 2023
1 parent 9c4ac74 commit 39d1739
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src/steamship/data/operations/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from steamship.base.request import Request
from steamship.base.response import Response
from steamship.data.block import Block
from steamship.data.tags.tag import Tag


class GenerateRequest(Request):
Expand Down Expand Up @@ -67,6 +68,10 @@ class GenerateRequest(Request):
# Default behavior if not provided is streaming=false
streaming: Optional[bool] = None

# Tags which will be applied to all Blocks that are generated as part of
# this request.
tags: Optional[List[Tag]] = None


class GenerateResponse(Response):
blocks: List[Block]
3 changes: 3 additions & 0 deletions src/steamship/data/plugin/plugin_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
HostingTimeout,
HostingType,
)
from steamship.data.tags.tag import Tag
from steamship.plugin.inputs.export_plugin_input import ExportPluginInput
from steamship.plugin.inputs.training_parameter_plugin_input import TrainingParameterPluginInput
from steamship.plugin.outputs.train_plugin_output import TrainPluginOutput
Expand Down Expand Up @@ -131,6 +132,7 @@ def generate(
make_output_public: Optional[bool] = None,
options: Optional[dict] = None,
streaming: Optional[bool] = None,
tags: Optional[List[Tag]] = None,
) -> Task[GenerateResponse]:
"""See GenerateRequest for description of parameter options"""
req = GenerateRequest(
Expand All @@ -148,6 +150,7 @@ def generate(
make_output_public=make_output_public,
options=options,
streaming=streaming,
tags=tags,
)
return self.client.post("plugin/instance/generate", req, expect=GenerateResponse)

Expand Down
27 changes: 26 additions & 1 deletion tests/steamship_tests/plugin/integration/test_e2e_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from steamship_tests.utils.deployables import deploy_plugin
from steamship_tests.utils.fixtures import get_steamship_client

from steamship import Block, File, MimeTypes, PluginInstance, Steamship
from steamship import Block, File, MimeTypes, PluginInstance, Steamship, Tag


def test_e2e_generator():
Expand Down Expand Up @@ -218,3 +218,28 @@ def test_generate_block_private_data(client: Steamship):

assert response.text == "PRETEND THIS IS THE DATA OF AN IMAGE"
assert response.headers["content-type"] == MimeTypes.PNG


@pytest.mark.usefixtures("client")
def test_generation_with_tags(client: Steamship):
plugin_instance = PluginInstance.create(client, plugin_handle="test-generator")
file = File.create(client, blocks=[Block(text="One block"), Block(text="two blocks")])
tags = [
Tag(kind="test_kind_0", name="test_name_0", value={"test_value": "test_value_0"}),
Tag(kind="test_kind_1", name="test_name_1", value={"test_value": "test_value_1"}),
]
generate_task = plugin_instance.generate(
input_file_id=file.id, append_output_to_file=True, tags=tags
)
blocks = generate_task.wait().blocks
assert blocks is not None
assert len(blocks) == 2
for block in blocks:
result_tags = block.tags
assert result_tags is not None
assert len(result_tags) == 2

Check failure on line 240 in tests/steamship_tests/plugin/integration/test_e2e_generator.py

View workflow job for this annotation

GitHub Actions / Production Test Results

test_e2e_generator.test_generation_with_tags

assert == failed. [pytest-clarity diff shown] #x1B[0m #x1B[0m#x1B[32mLHS#x1B[0m vs #x1B[31mRHS#x1B[0m shown below #x1B[0m #x1B[0m#x1B[32m0#x1B[0m #x1B[0m#x1B[31m2#x1B[0m #x1B[0m
Raw output
client = Steamship(config=Configuration(api_key=SecretStr('**********'), api_base=AnyHttpUrl('https://api.steamship.com/api/v1/...kspace_id='C42D9A97-1B3D-4D34-A311-B3CD41651E76', workspace_handle='test_h2vrreabsz', profile='test', request_id=None))

    @pytest.mark.usefixtures("client")
    def test_generation_with_tags(client: Steamship):
        plugin_instance = PluginInstance.create(client, plugin_handle="test-generator")
        file = File.create(client, blocks=[Block(text="One block"), Block(text="two blocks")])
        tags = [
            Tag(kind="test_kind_0", name="test_name_0", value={"test_value": "test_value_0"}),
            Tag(kind="test_kind_1", name="test_name_1", value={"test_value": "test_value_1"}),
        ]
        generate_task = plugin_instance.generate(
            input_file_id=file.id, append_output_to_file=True, tags=tags
        )
        blocks = generate_task.wait().blocks
        assert blocks is not None
        assert len(blocks) == 2
        for block in blocks:
            result_tags = block.tags
            assert result_tags is not None
>           assert len(result_tags) == 2
E           assert == failed. [pytest-clarity diff shown]
E             #x1B[0m
E             #x1B[0m#x1B[32mLHS#x1B[0m vs #x1B[31mRHS#x1B[0m shown below
E             #x1B[0m
E             #x1B[0m#x1B[32m0#x1B[0m
E             #x1B[0m#x1B[31m2#x1B[0m
E             #x1B[0m

tests/steamship_tests/plugin/integration/test_e2e_generator.py:240: AssertionError
for i in range(2):
result_tag = result_tags[i]
assert result_tag.kind == f"test_kind_{i}"
assert result_tag.name == f"test_name_{i}"
assert result_tag.value == {"test_value": f"test_value_{i}"}

0 comments on commit 39d1739

Please sign in to comment.