@@ -1,9 +1,8 @@
 from azureml.pipeline.core.graph import PipelineParameter
 from azureml.pipeline.steps import PythonScriptStep
 from azureml.pipeline.core import Pipeline, PipelineData
-from azureml.core import Workspace
+from azureml.core import Workspace, Dataset, Datastore
 from azureml.core.runconfig import RunConfiguration
-from azureml.core import Dataset
 from ml_service.util.attach_compute import get_compute
 from ml_service.util.env_variables import Env
 from ml_service.util.manage_environment import get_environment
@@ -39,8 +38,20 @@ def main():
     run_config = RunConfiguration()
     run_config.environment = environment
 
+    if e.datastore_name:
+        datastore_name = e.datastore_name
+    else:
+        datastore_name = aml_workspace.get_default_datastore().name
+    run_config.environment.environment_variables["DATASTORE_NAME"] = datastore_name  # NOQA: E501
+
     model_name_param = PipelineParameter(
         name="model_name", default_value=e.model_name)
+    dataset_version_param = PipelineParameter(
+        name="dataset_version", default_value=e.dataset_version)
+    data_file_path_param = PipelineParameter(
+        name="data_file_path", default_value="none")
+    caller_run_id_param = PipelineParameter(
+        name="caller_run_id", default_value="none")
 
     # Get dataset name
     dataset_name = e.dataset_name
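Note: the three new PipelineParameter objects become overridable inputs when the pipeline is run. A minimal sketch of supplying them at submission time, assuming the step below is assembled into a Pipeline object named pipeline; the experiment name and parameter values here are illustrative, only the parameter names come from this diff:

    from azureml.core import Experiment, Workspace

    ws = Workspace.from_config()  # assumes a local config.json
    experiment = Experiment(ws, "diabetes-train")  # hypothetical name

    # Parameters not overridden here fall back to the defaults above
    run = experiment.submit(
        pipeline,
        pipeline_parameters={
            "model_name": "diabetes_model.pkl",
            "dataset_version": "latest",
            "data_file_path": "none",
            "caller_run_id": "none",
        },
    )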
@@ -57,9 +68,9 @@ def main():
         df.to_csv(file_name, index=False)
 
-        # Upload file to default datastore in workspace
-        default_ds = aml_workspace.get_default_datastore()
+        # Upload file to the datastore selected above
+        datastore = Datastore.get(aml_workspace, datastore_name)
         target_path = 'training-data/'
-        default_ds.upload_files(
+        datastore.upload_files(
             files=[file_name],
             target_path=target_path,
             overwrite=True,
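Because DATASTORE_NAME is exported through run_config further up, a step script can mirror the same fallback logic at runtime. A sketch under that assumption; resolve_datastore is a hypothetical helper, not part of this repo:

    import os

    from azureml.core import Datastore


    def resolve_datastore(workspace):
        # Hypothetical helper: prefer the name exported via DATASTORE_NAME,
        # otherwise fall back to the workspace default datastore.
        name = os.environ.get("DATASTORE_NAME")
        if name:
            return Datastore.get(workspace, name)
        return workspace.get_default_datastore()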
@@ -68,17 +79,14 @@ def main():
         # Register dataset
         path_on_datastore = os.path.join(target_path, file_name)
         dataset = Dataset.Tabular.from_delimited_files(
-            path=(default_ds, path_on_datastore))
+            path=(datastore, path_on_datastore))
         dataset = dataset.register(
             workspace=aml_workspace,
             name=dataset_name,
             description='diabetes training data',
             tags={'format': 'CSV'},
             create_new_version=True)
 
-    # Get the dataset
-    dataset = Dataset.get_by_name(aml_workspace, dataset_name)
-
     # Create a PipelineData to pass data between steps
     pipeline_data = PipelineData(
         'pipeline_data',
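With the eager Dataset.get_by_name call deleted here, and the inputs= binding removed in the next hunk, the training script presumably resolves the dataset itself from the arguments it receives. A hedged sketch of that lookup; the dataset name and version values are placeholders:

    from azureml.core import Dataset, Run

    run = Run.get_context()
    ws = run.experiment.workspace

    # These are assumed to arrive via the CLI arguments added below;
    # "latest" selects the newest registered version.
    dataset_name = "diabetes_ds"  # hypothetical value
    dataset_version = "latest"    # hypothetical value
    version = dataset_version if dataset_version == "latest" else int(dataset_version)
    dataset = Dataset.get_by_name(ws, name=dataset_name, version=version)
    df = dataset.to_pandas_dataframe()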
@@ -89,11 +97,14 @@ def main():
         script_name=e.train_script_path,
         compute_target=aml_compute,
         source_directory=e.sources_directory_train,
-        inputs=[dataset.as_named_input('training_data')],
         outputs=[pipeline_data],
         arguments=[
             "--model_name", model_name_param,
-            "--step_output", pipeline_data
+            "--step_output", pipeline_data,
+            "--dataset_version", dataset_version_param,
+            "--data_file_path", data_file_path_param,
+            "--caller_run_id", caller_run_id_param,
+            "--dataset_name", dataset_name,
         ],
         runconfig=run_config,
         allow_reuse=False,
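For the step to start, the train script has to accept every flag wired up in arguments= above. A minimal argparse sketch matching those names; the types and defaults are assumptions, not taken from this diff:

    import argparse

    # Flag names mirror the arguments= list in the PythonScriptStep above
    parser = argparse.ArgumentParser("train")
    parser.add_argument("--model_name", type=str)
    parser.add_argument("--step_output", type=str)
    parser.add_argument("--dataset_version", type=str, default="latest")
    parser.add_argument("--data_file_path", type=str, default="none")
    parser.add_argument("--caller_run_id", type=str, default="none")
    parser.add_argument("--dataset_name", type=str)
    args = parser.parse_args()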