Skip to content

Commit

Permalink
Stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
eob committed Apr 30, 2022
1 parent 296a295 commit 1b58676
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 8 deletions.
10 changes: 10 additions & 0 deletions src/steamship/data/plugin_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from steamship.base import Client, Request
from steamship.base.response import Response
from steamship.plugin.inputs.export_plugin_input import ExportPluginInput
from steamship.plugin.inputs.training_parameter_plugin_input import TrainingParameterPluginInput
from steamship.plugin.outputs.raw_data_plugin_output import RawDataPluginOutput
from steamship.plugin.outputs.training_parameter_plugin_output import TrainingParameterPluginOutput


Expand Down Expand Up @@ -92,6 +94,14 @@ def delete(self) -> PluginInstance:
expect=PluginInstance
)

def export(self, input: ExportPluginInput) -> RawDataPluginOutput:
input.pluginInstance = self.handle
return self.client.post(
'plugin/instance/export',
payload=input,
expect=RawDataPluginOutput
)

def train(self, trainingRequest: TrainingParameterPluginInput) -> PluginInstance:
return self.client.post(
'plugin/instance/train',
Expand Down
8 changes: 4 additions & 4 deletions src/steamship/plugin/inputs/export_plugin_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def from_dict(d: any = None, client: Client = None) -> "ExportPluginInput":

return ExportPluginInput(
pluginInstance = d.get('pluginInstance', None),
id = d.get(': str = None', None),
handle = d.get(': str = None', None),
type = d.get(': str = None', None),
filename = d.get(': str = None', None)
id = d.get('id', None),
handle = d.get('id', None),
type = d.get('type', None),
filename = d.get('filename', None)
)

def to_dict(self) -> Dict:
Expand Down
4 changes: 2 additions & 2 deletions src/steamship/plugin/inputs/train_plugin_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def from_dict(d: any = None, client: Client = None) -> "TrainPluginInput":
return None

return TrainPluginInput(
tenantId = d.get(': str = None', None),
spaceId = d.get(': str = None', None),
tenantId = d.get('tenantId', None),
spaceId = d.get('spaceId', None),
pluginInstance=d.get('pluginInstance', None),
pluginInstanceId=d.get('pluginInstanceId', None),
modelName = d.get('modelName', None),
Expand Down
4 changes: 2 additions & 2 deletions src/steamship/plugin/inputs/train_plugin_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def from_dict(d: any = None, client: Client = None) -> "TrainPluginOutput":
return None

return TrainPluginOutput(
tenantId = d.get(': str = None', None),
spaceId = d.get(': str = None', None),
tenantId = d.get('tenantId', None),
spaceId = d.get('spaceId', None),
modelName = d.get('modelName', None),
modelFilename = d.get('modelFilename', None),
modelUploadUrl = d.get('modelUploadUrl', None),
Expand Down
55 changes: 55 additions & 0 deletions tests/plugin/test_e2e_corpus_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from dataclasses import asdict

from steamship.data.plugin_instance import PluginInstance
from steamship.data.plugin import TrainingPlatform
from steamship.extension.file import File
from steamship.plugin.inputs.export_plugin_input import ExportPluginInput
from steamship.plugin.inputs.training_parameter_plugin_input import TrainingParameterPluginInput
import time

from ..client.helpers import deploy_plugin, upload_file, _steamship

__copyright__ = "Steamship"
__license__ = "MIT"

EXPORTER_HANDLE = "signed-url-exporter"

def test_e2e_corpus_export():
client = _steamship()
versionConfigTemplate = dict(
textColumn=dict(type="string"),
tagColumns=dict(type="string"),
tagKind=dict(type="string")
)
instanceConfig = dict(
textColumn="Message",
tagColumns="Category",
tagKind="Intent"
)
exporterPluginR = PluginInstance.create(
client=client,
handle=EXPORTER_HANDLE,
pluginHandle=EXPORTER_HANDLE,
upsert=True
)
assert (exporterPluginR.data is not None)
exporterPlugin = exporterPluginR.data
assert (exporterPlugin.handle is not None)

input = ExportPluginInput(handle='default', type="corpus")
print(asdict(input))

# Make a blockifier which will generate our training corpus
with deploy_plugin("plugin_blockifier_csv.py", "blockifier", versionConfigTemplate=versionConfigTemplate, instanceConfig=instanceConfig) as (plugin, version, instance):
with upload_file("utterances.csv") as file:
assert (len(file.query().data.blocks) == 0)
# Use the plugin we just registered
file.blockify(pluginInstance=instance.handle).wait()
assert (len(file.query().data.blocks) == 5)

# Now export the corpus
rawDataR = exporterPlugin.export(input)
assert (rawDataR is not None)

# The results of a corpus exporter are MD5 encoded!
rawData = rawDataR.data
1 change: 1 addition & 0 deletions tests/plugin/test_e2e_trainable_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@

def test_e2e_trainable_tagger_ecs_training():
client = _steamship()

versionConfigTemplate = dict(
textColumn=dict(type="string"),
tagColumns=dict(type="string"),
Expand Down

0 comments on commit 1b58676

Please sign in to comment.