Add summarization for distillation (#3454)
* Add summarization for distillation

* reformat

* Remove large file

* Fix comments

* Update one more comment
sanchez-alex authored Dec 5, 2024
1 parent 7fea839 commit 0c0b38f
Showing 11 changed files with 1,213 additions and 5 deletions.
33 changes: 33 additions & 0 deletions cli/foundation-models/system/distillation/summarization/README.md
@@ -0,0 +1,33 @@
# Distillation with CLI (Summarization)

## 1. Create the Job
Ensure you have the proper setup:
1. Run `az version` and confirm that the `ml` extension is installed. The `ml` extension version should be greater than or equal to 2.32.0.
2. If the `ml` extension is not installed, run `az extension add -n ml`. The snippet after this list shows how the extension can be checked and updated.
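
As a minimal sketch using only standard Azure CLI commands (the JMESPath `--query` filter is just one way to read the installed version):

```text
# Show the installed version of the ml extension, if any
az extension list --query "[?name=='ml'].version" -o tsv

# Install the extension if it is missing
az extension add -n ml

# Update an existing installation to the latest version
az extension update -n ml
```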

Run the Distillation CLI command, pointing to the YAML file in this folder, and fill in the required Azure ML IDs:

```text
az ml job create --file distillation_summarization.yaml --workspace-name [YOUR_AZURE_WORKSPACE] --resource-group [YOUR_AZURE_RESOURCE_GROUP] --subscription [YOUR_AZURE_SUBSCRIPTION]
```
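
Once submitted, the job can be monitored from the studio UI or from the CLI. The following is a hedged sketch, assuming the job keeps the `name` set in the YAML (`Distillation-summarization-llama`) and that your `ml` extension version provides `az ml job show` and `az ml job stream`:

```text
# Check the status of the submitted distillation job
az ml job show --name Distillation-summarization-llama --workspace-name [YOUR_AZURE_WORKSPACE] --resource-group [YOUR_AZURE_RESOURCE_GROUP] --subscription [YOUR_AZURE_SUBSCRIPTION]

# Stream the job logs to the terminal
az ml job stream --name Distillation-summarization-llama --workspace-name [YOUR_AZURE_WORKSPACE] --resource-group [YOUR_AZURE_RESOURCE_GROUP] --subscription [YOUR_AZURE_SUBSCRIPTION]
```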

**Note:** To see how the train and validation files were created, see section 2 of this [notebook](/sdk/python/foundation-models/system/distillation/summarization/distillation_summarization.ipynb). The sketch below illustrates the general shape each JSONL record is expected to have.
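
This is a rough illustration only; the linked notebook is the authoritative source for the data schema. It assumes the chat-style `messages` layout used by the distillation examples, and the system/user content shown here is invented:

```text
{"messages": [{"role": "system", "content": "You are a helpful assistant that summarizes documents."}, {"role": "user", "content": "Summarize the following article: <article text>"}]}
```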

## 2. Deploy to Endpoint
Once the distilled model is ready, you can deploy it through the UI or the CLI.

### UI Deployment
1. Navigate to the `model` tab in [ml studio](https://ml.azure.com), or to the `Finetuning` tab in the [ai platform](https://ai.azure.com).
2. If using ml studio, locate the model using the `name` of the `registered_model` in the YAML file used to create this job, then select Deploy to deploy a serverless endpoint. If using the ai platform, search for the name of the job (in this example, `Distillation-summarization-llama`), click on it, and select Deploy to deploy a serverless endpoint.

### CLI Deployment
Fill out the serverless_endpoint.yaml file in this folder. The necessary information can be found by:
1. Navigating to the `model` tab in [ml studio](https://ml.azure.com).
2. Selecting the model whose `name` matches the `registered_model` name in the YAML file used to create this job. In this example, the name to use is `llama-summarization-distilled`.
3. Using the model's `asset_id` to fill out the `model_id` in the YAML. The model can also be inspected from the CLI, as shown in the sketch after this list.
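
A minimal sketch of that CLI lookup, assuming `az ml model list` and `az ml model show` behave as in current `ml` extension versions; the identifier field in the CLI output may be formatted differently from the asset ID shown in the studio UI, so cross-check before copying it into the YAML:

```text
# List the registered versions of the distilled model
az ml model list --name llama-summarization-distilled --workspace-name [YOUR_AZURE_WORKSPACE] --resource-group [YOUR_AZURE_RESOURCE_GROUP] --subscription [YOUR_AZURE_SUBSCRIPTION]

# Show the details of a specific version (replace 1 with your version)
az ml model show --name llama-summarization-distilled --version 1 --workspace-name [YOUR_AZURE_WORKSPACE] --resource-group [YOUR_AZURE_RESOURCE_GROUP] --subscription [YOUR_AZURE_SUBSCRIPTION]
```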

With the information filled out, run the command:

```text
az ml serverless-endpoint create -f serverless_endpoint.yaml
```
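
After the endpoint is created, its URI and keys can be retrieved so the deployed model can be called through its chat completions API. This is a hedged sketch, assuming your `ml` extension version provides the `show` and `get-credentials` subcommands for serverless endpoints:

```text
# Show the endpoint details, including its URI
az ml serverless-endpoint show --name llama-summarization-distilled --workspace-name [YOUR_AZURE_WORKSPACE] --resource-group [YOUR_AZURE_RESOURCE_GROUP] --subscription [YOUR_AZURE_SUBSCRIPTION]

# Retrieve the keys used to authenticate requests to the endpoint
az ml serverless-endpoint get-credentials --name llama-summarization-distilled --workspace-name [YOUR_AZURE_WORKSPACE] --resource-group [YOUR_AZURE_RESOURCE_GROUP] --subscription [YOUR_AZURE_SUBSCRIPTION]
```
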
59 changes: 59 additions & 0 deletions cli/foundation-models/system/distillation/summarization/distillation_summarization.yaml
@@ -0,0 +1,59 @@
type: distillation

name: "Distillation-summarization-llama"
description: "Distill student model using a teacher model"
experiment_name: "Distillation-summarization"

# Data Generation Properties
data_generation_type: label_generation
data_generation_task_type: summarization

# Input data
training_data:
  type: uri_file
  path: ./train_summarization.jsonl
validation_data:
  type: uri_file
  path: ./validation_summarization.jsonl

# Teacher model serverless endpoint information
teacher_model_endpoint_connection:
  type: serverless
  name: Meta-Llama-3-1-405B-Instruct-vkn
  endpoint: https://Meta-Llama-3-1-405B-Instruct-vkn.westus3.models.ai.azure.com/chat/completions
  api_key: EXAMPLE_API_KEY

# Model ID
student_model: azureml://registries/azureml-meta/models/Meta-Llama-3.1-8B-Instruct/versions/2

# Output distilled model
outputs:
  registered_model:
    type: mlflow_model
    name: llama-summarization-distilled


# Teacher model related properties (OPTIONAL)
teacher_model_settings:
  inference_parameters:
    temperature: 0.8
    max_tokens: 1024
    top_p: 0.95
  endpoint_request_settings:
    request_batch_size: 10
    min_endpoint_success_ratio: 0.7

# System prompt settings (OPTIONAL)
prompt_settings:
  enable_chain_of_density: true
  max_len_summary: 80

# For finetuning (OPTIONAL)
hyperparameters:
  learning_rate_multiplier: "0.2"
  n_epochs: "5"
  batch_size: "2"

# Resource for Data Generation Step (OPTIONAL)
resources:
  instance_type: Standard_D2_v2
2 changes: 2 additions & 0 deletions cli/foundation-models/system/distillation/summarization/serverless_endpoint.yaml
@@ -0,0 +1,2 @@
name: llama-summarization-distilled
model_id: azureml://locations/{AI_PROJECT_LOCATION}/workspaces/{WORKSPACE_ID}/models/llama-summarization-distilled/versions/{VERSION}

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions sdk/python/foundation-models/system/distillation/README.md
@@ -47,9 +47,11 @@ We currently support numerous task types for model distillation. To view example
- [Conversation](./conversation/distillation_conversational_task.ipynb)
- [NLU QA (Natural Language Understanding Question and Answer)](./nlu_qa/distillation_nlu_qa_task.ipynb)
- [Math](./math/distillation_math.ipynb)
- [Summarization](./summarization/distillation_summarization.ipynb)

### CLI Examples
- [NLI (Natural Language Interpretation)](/cli/foundation-models/system/distillation/nli/README.md)
- [Conversation](/cli/foundation-models/system/distillation/conversation/README.md)
- [NLU QA (Natural Language Understanding Question and Answer)](/cli/foundation-models/system/distillation/nlu_qa/README.md)
- [Math](/cli/foundation-models/system/distillation/math/README.md)
- [Summarization](/cli/foundation-models/system/distillation/summarization/README.md)
@@ -349,7 +349,7 @@
"\n",
"# If validation data was registered to workspace already, navigate to the Data tab, select the data to use and use the 'Named asset URI'.\n",
"# Example of the format is seen below\n",
"# train_data = \"azureml:validation_conversation_quora:1\""
"# valid_data = \"azureml:validation_conversation_quora:1\""
]
},
{
@@ -349,7 +349,7 @@
"valid_data = Input(type=AssetTypes.URI_FILE, path=valid_data_path)\n",
"\n",
"# If validation data was registered to workspace already, navigate to the Data tab, select the data to use and use the 'Named asset URI'\n",
"# train_data = \"azureml:math_valid_multi_arith:1\""
"# valid_data = \"azureml:math_valid_multi_arith:1\""
]
},
{
@@ -350,7 +350,7 @@
"valid_data = Input(type=AssetTypes.URI_FILE, path=valid_data_path)\n",
"\n",
"# If validation data was registered to workspace already, navigate to the Data tab, select the data to use and use the 'Named asset URI'\n",
"# train_data = \"azureml:nli_train:1\""
"# valid_data = \"azureml:nli_validation:1\""
]
},
{
@@ -353,7 +353,7 @@
"valid_data = Input(type=AssetTypes.URI_FILE, path=valid_data_path)\n",
"\n",
"# If validation data was registered to workspace already, navigate to the Data tab, select the data to use and use the 'Named asset URI'\n",
"# train_data = \"azureml:nlu_qa_valid_cqna:1\""
"# valid_data = \"azureml:nlu_qa_valid_cqna:1\""
]
},
{
@@ -337,7 +337,7 @@
"valid_data = Input(type=AssetTypes.URI_FILE, path=valid_data_path)\n",
"\n",
"# If validation data was registered to workspace already, navigate to the Data tab, select the data to use and use the 'Named asset URI'\n",
"# train_data = \"azureml:nlu_qa_valid_aqua_rat:1\""
"# valid_data = \"azureml:nlu_qa_valid_aqua_rat:1\""
]
},
{