Batch scoring SDK updates #3659

Open · wants to merge 9 commits into base: main
@@ -14,7 +14,7 @@ inputs:
description: An optional configuration file to use for deployment settings. This overrides passed-in parameters.
scoring_url:
type: string
optional: false
optional: true
description: The URL of the endpoint.
model_type:
type: string
@@ -73,20 +73,15 @@ inputs:
description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
mode: rw_mount

teacher_model_endpoint_url:
type: string
optional: true
description: Teacher model endpoint URL

teacher_model_endpoint_name:
type: string
optional: true
description: Teacher model endpoint name

teacher_model_endpoint_key:
teacher_model_connection_name:
type: string
optional: true
description: Teacher model endpoint key
description: Teacher model connection name

teacher_model_max_new_tokens:
type: integer
@@ -199,20 +194,6 @@ inputs:
type: string
optional: true
description: Config file path that contains deployment configurations
additional_headers:
type: string
optional: True
description: JSON serialized string expressing additional headers to be added to each request.
debug_mode:
type: boolean
optional: False
default: False
description: Enable debug mode to print all the debug logs in the score step.
ensure_ascii:
type: boolean
optional: False
default: False
description: If set to true, the output is guaranteed to have all incoming non-ASCII characters escaped. If set to false, these characters will be output as-is. More detailed information can be found at https://docs.python.org/3/library/json.html
max_retry_time_interval:
type: integer
optional: True
@@ -266,8 +247,7 @@ jobs:
train_file_path: '${{parent.inputs.train_file_path}}'
validation_file_path: '${{parent.inputs.validation_file_path}}'
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
teacher_model_connection_name: '${{parent.inputs.teacher_model_connection_name}}'
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}'
max_len_summary: '${{parent.inputs.max_len_summary}}'
@@ -303,13 +283,9 @@ jobs:
identity:
type: user_identity
inputs:
scoring_url: ${{parent.inputs.teacher_model_endpoint_url}}
deployment_name: ${{parent.inputs.teacher_model_endpoint_name}}
authentication_type: ${{parent.inputs.authentication_type}}
configuration_file: ${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.batch_config_connection}}
additional_headers: ${{parent.inputs.additional_headers}}
debug_mode: ${{parent.inputs.debug_mode}}
ensure_ascii: ${{parent.inputs.ensure_ascii}}
max_retry_time_interval: ${{parent.inputs.max_retry_time_interval}}
initial_worker_count: ${{parent.inputs.initial_worker_count}}
max_worker_count: ${{parent.inputs.max_worker_count}}
@@ -408,6 +384,7 @@ jobs:
hash_validation_data: '${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.hash_validation_data}}'
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}'
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
connection_config_file: ${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.batch_config_connection}}
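The changes above drop the explicit `teacher_model_endpoint_url` / `teacher_model_endpoint_key` inputs in favour of a single `teacher_model_connection_name`. As a rough illustration only (not code from this PR), the sketch below shows one way a component could resolve such a named workspace connection into an endpoint URL and key at runtime; the helper name `resolve_teacher_connection`, the credential handling, and the assumption that the connection stores an API-key credential are all illustrative.

```python
# Illustrative sketch only: resolve a workspace connection name into an endpoint
# URL and key using the azure-ai-ml SDK. Not the implementation in this PR.
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential


def resolve_teacher_connection(subscription_id: str, resource_group: str,
                               workspace_name: str, connection_name: str):
    """Return (endpoint_url, endpoint_key) for a named workspace connection."""
    ml_client = MLClient(
        credential=DefaultAzureCredential(),
        subscription_id=subscription_id,
        resource_group_name=resource_group,
        workspace_name=workspace_name,
    )
    connection = ml_client.connections.get(name=connection_name)
    endpoint_url = connection.target  # connection target, assumed to be the scoring URL
    # Assumes an API-key credential; retrieving the secret may require extra
    # permissions or a different call depending on the SDK version.
    endpoint_key = getattr(connection.credentials, "key", None)
    return endpoint_url, endpoint_key
```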
@@ -5,7 +5,7 @@ type: command

is_deterministic: False

display_name: OSS Distillation Generate Data Postprocess Batch Scoring
display_name: OSS Distillation Generate Data Batch Scoring Postprocess
description: Component to prepare data returned from the teacher model endpoint in batch

environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/76
@@ -62,6 +62,11 @@ inputs:
default: "false"
description: Enable Chain of density for text summarization

teacher_model_endpoint_name:
type: string
optional: true
description: Teacher model endpoint name

data_generation_task_type:
type: string
enum:
@@ -104,6 +109,7 @@ command: >-
--min_endpoint_success_ratio ${{inputs.min_endpoint_success_ratio}}
$[[--enable_chain_of_thought ${{inputs.enable_chain_of_thought}}]]
$[[--enable_chain_of_density ${{inputs.enable_chain_of_density}}]]
$[[--teacher_model_endpoint_name ${{inputs.teacher_model_endpoint_name}}]]
--data_generation_task_type ${{inputs.data_generation_task_type}}
--connection_config_file ${{inputs.connection_config_file}}
--generated_batch_train_file_path ${{outputs.generated_batch_train_file_path}}
@@ -28,15 +28,10 @@ inputs:
optional: true
description: Teacher model endpoint name

teacher_model_endpoint_url:
teacher_model_connection_name:
type: string
optional: true
description: Teacher model endpoint url

teacher_model_endpoint_key:
type: string
optional: true
description: Teacher model endpoint key
description: Teacher model connection name

teacher_model_max_new_tokens:
type: integer
@@ -133,8 +128,7 @@ command: >-
--train_file_path ${{inputs.train_file_path}}
$[[--validation_file_path ${{inputs.validation_file_path}}]]
$[[--teacher_model_endpoint_name ${{inputs.teacher_model_endpoint_name}}]]
$[[--teacher_model_endpoint_url ${{inputs.teacher_model_endpoint_url}}]]
$[[--teacher_model_endpoint_key ${{inputs.teacher_model_endpoint_key}}]]
$[[--teacher_model_connection_name ${{inputs.teacher_model_connection_name}}]]
--teacher_model_max_new_tokens ${{inputs.teacher_model_max_new_tokens}}
--teacher_model_temperature ${{inputs.teacher_model_temperature}}
--teacher_model_top_p ${{inputs.teacher_model_top_p}}
28 changes: 8 additions & 20 deletions assets/training/distillation/components/pipeline/spec.yaml
@@ -78,6 +78,11 @@ inputs:
optional: true
description: Teacher model endpoint name

teacher_model_connection_name:
type: string
optional: true
description: Teacher model connection name

teacher_model_endpoint_url:
type: string
optional: true
@@ -181,20 +186,6 @@ inputs:
enum:
- azureml_workspace_connection
- managed_identity
additional_headers:
type: string
optional: True
description: JSON serialized string expressing additional headers to be added to each request.
debug_mode:
type: boolean
optional: False
default: False
description: Enable debug mode to print all the debug logs in the score step.
ensure_ascii:
type: boolean
optional: False
default: False
description: If set to true, the output is guaranteed to have all incoming non-ASCII characters escaped. If set to false, these characters will be output as-is. More detailed information can be found at https://docs.python.org/3/library/json.html
max_retry_time_interval:
type: integer
optional: True
@@ -220,7 +211,7 @@ inputs:
mini_batch_size:
type: string
optional: true
default: 100KB
default: 10KB
description: The mini batch size for parallel run.

# ########################### Finetuning Component ########################### #
@@ -307,6 +298,7 @@ jobs:
train_file_path: '${{parent.inputs.train_file_path}}'
validation_file_path: '${{parent.inputs.validation_file_path}}'
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
teacher_model_connection_name: '${{parent.inputs.teacher_model_connection_name}}'
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
@@ -359,9 +351,8 @@ jobs:
compute_finetune: '${{parent.inputs.compute_finetune}}'
train_file_path: '${{parent.inputs.train_file_path}}'
validation_file_path: '${{parent.inputs.validation_file_path}}'
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
teacher_model_connection_name: '${{parent.inputs.teacher_model_connection_name}}'
teacher_model_max_new_tokens: '${{parent.inputs.teacher_model_max_new_tokens}}'
teacher_model_temperature: '${{parent.inputs.teacher_model_temperature}}'
teacher_model_top_p: '${{parent.inputs.teacher_model_top_p}}'
@@ -377,9 +368,6 @@ jobs:
per_device_train_batch_size: '${{parent.inputs.per_device_train_batch_size}}'
learning_rate: '${{parent.inputs.learning_rate}}'
authentication_type: '${{parent.inputs.authentication_type}}'
additional_headers: '${{parent.inputs.additional_headers}}'
debug_mode: '${{parent.inputs.debug_mode}}'
ensure_ascii: '${{parent.inputs.ensure_ascii}}'
max_retry_time_interval: '${{parent.inputs.max_retry_time_interval}}'
initial_worker_count: '${{parent.inputs.initial_worker_count}}'
max_worker_count: '${{parent.inputs.max_worker_count}}'
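For callers of the updated pipeline spec, the hedged sketch below shows how the new `teacher_model_connection_name` input could be supplied alongside the existing endpoint name when submitting the pipeline. The workspace details, asset paths, endpoint/connection names, and task type are placeholders, not values from this PR.

```python
# Hypothetical usage sketch: submit the distillation pipeline with the new
# teacher_model_connection_name input. All names and paths here are placeholders.
from azure.ai.ml import Input, MLClient, load_component
from azure.ai.ml.dsl import pipeline
from azure.identity import DefaultAzureCredential

ml_client = MLClient(
    DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<workspace>"
)
distillation = load_component(
    source="assets/training/distillation/components/pipeline/spec.yaml"
)


@pipeline()
def distillation_job():
    distillation(
        train_file_path=Input(type="uri_file", path="azureml:train_data:1"),
        validation_file_path=Input(type="uri_file", path="azureml:validation_data:1"),
        teacher_model_endpoint_name="my-teacher-endpoint",
        # New in this PR; replaces teacher_model_endpoint_url/key on the batch scoring path.
        teacher_model_connection_name="my-teacher-connection",
        data_generation_task_type="NLI",
    )


ml_client.jobs.create_or_update(distillation_job(), experiment_name="distillation")
```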
@@ -40,6 +40,11 @@ inputs:
optional: true
description: Teacher model endpoint key

teacher_model_connection_name:
type: string
optional: true
description: Teacher model connection name

teacher_model_max_new_tokens:
type: integer
default: 128
@@ -148,6 +153,7 @@ command: >-
$[[--teacher_model_endpoint_name ${{inputs.teacher_model_endpoint_name}}]]
$[[--teacher_model_endpoint_url ${{inputs.teacher_model_endpoint_url}}]]
$[[--teacher_model_endpoint_key ${{inputs.teacher_model_endpoint_key}}]]
$[[--teacher_model_connection_name ${{inputs.teacher_model_connection_name}}]]
--teacher_model_max_new_tokens ${{inputs.teacher_model_max_new_tokens}}
--teacher_model_temperature ${{inputs.teacher_model_temperature}}
--teacher_model_top_p ${{inputs.teacher_model_top_p}}
1 change: 1 addition & 0 deletions assets/training/distillation/src/common/constants.py
@@ -120,6 +120,7 @@ class TelemetryConstants:
POST_PROCESS_TRAINING_DATA = "post_process_training_data"
POST_PROCESS_VALIDATION_DATA = "post_process_validation_data"
PROCESS_DATASET_RECORD = "process_dataset_record"
DELETE_WORKSPACE_CONNECTION = "delete_workspace_connection"

VALIDATOR = "validator"
ML_CLIENT_INITIALISATION = "ml_client_initialisation"
20 changes: 16 additions & 4 deletions assets/training/distillation/src/generate_data_postprocess.py
@@ -143,6 +143,13 @@ def get_parser():
choices=[v.value for v in DataGenerationTaskType],
)

parser.add_argument(
"--teacher_model_endpoint_name",
type=str,
required=False,
help="Teacher model endpoint name",
)

parser.add_argument(
"--min_endpoint_success_ratio",
type=float,
@@ -308,17 +315,22 @@ def data_import(args: Namespace):
hash_train_data = args.hash_train_data
hash_validation_data = args.hash_validation_data
connection_config_file = args.connection_config_file
teacher_model_endpoint_name = args.teacher_model_endpoint_name

enable_cot = True if enable_cot_str.lower() == "true" else False
enable_cod = True if enable_cod_str.lower() == "true" else False

if teacher_model_endpoint_name:
with log_activity(
logger=logger, activity_name=TelemetryConstants.DELETE_WORKSPACE_CONNECTION
):
logger.info(
"Deleting batch configuration connection when teacher model endpoint name is provided."
)
delete_connection(connection_config_file)
with log_activity(
logger=logger, activity_name=TelemetryConstants.POST_PROCESS_TRAINING_DATA
):
logger.info(
"Deleting batch configuration connection used for teacher model invocation."
)
delete_connection(connection_config_file)
logger.info(
"Running data postprocessing for train file path: %s", train_file_path
)
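With this change, the postprocess step deletes the temporary batch-scoring workspace connection whenever a teacher model endpoint name is provided. The `delete_connection` helper itself is imported from elsewhere in the repository; below is a purely hypothetical sketch of what such a helper might do — the config-file key and the SDK calls are assumptions, not the repository's implementation.

```python
# Hypothetical sketch of a delete_connection helper; the real helper may differ.
import json

from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential


def delete_connection(connection_config_file: str) -> None:
    """Delete the temporary workspace connection named in the batch config file."""
    with open(connection_config_file) as f:
        config = json.load(f)  # assumed to be JSON with a "connection_name" field
    connection_name = config["connection_name"]
    ml_client = MLClient.from_config(credential=DefaultAzureCredential())
    ml_client.connections.delete(name=connection_name)
```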