Commit
* Add distillation component
* throw exception, update copyrights
* fix flakes
* fix flakes
* update
1 parent d317ef1, commit e277f58
Showing 8 changed files with 1,157 additions and 0 deletions.
assets/training/distillation/components/data_generation/asset.yaml (3 additions, 0 deletions)
@@ -0,0 +1,3 @@
type: component
spec: spec.yaml
categories: ["Foundational Models", "Finetune", "Distillation"]
assets/training/distillation/components/data_generation/spec.yaml (117 additions, 0 deletions)
@@ -0,0 +1,117 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_generate_data
version: 0.0.1
type: command

is_deterministic: True

display_name: OSS Distillation Generate Data
description: Component to generate data from teacher model endpoint

environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/63

inputs:
  # Inputs
  train_file_path:
    type: uri_file
    description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
    mode: rw_mount

  validation_file_path:
    type: uri_file
    optional: true
    description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
    mode: rw_mount

  teacher_model_endpoint_name:
    type: string
    optional: true
    description: Teacher model endpoint name

  teacher_model_endpoint_url:
    type: string
    optional: true
    description: Teacher model endpoint URL

  teacher_model_endpoint_key:
    type: string
    optional: true
    description: Teacher model endpoint key

  teacher_model_max_new_tokens:
    type: integer
    default: 128
    description: Teacher model max_new_tokens inference parameter

  teacher_model_temperature:
    type: number
    default: 0.2
    description: Teacher model temperature inference parameter

  teacher_model_top_p:
    type: number
    default: 0.1
    description: Teacher model top_p inference parameter

  teacher_model_frequency_penalty:
    type: number
    default: 0.0
    description: Teacher model frequency penalty inference parameter

  teacher_model_presence_penalty:
    type: number
    default: 0.0
    description: Teacher model presence penalty inference parameter

  teacher_model_stop:
    type: string
    optional: true
    description: Teacher model stop inference parameter

  request_batch_size:
    type: integer
    default: 10
    description: Number of data records to send to the teacher model endpoint in one request

  min_endpoint_success_ratio:
    type: number
    default: 0.7
    description: >
      The minimum value of (successful_requests / total_requests) required for classifying inference as successful.
      If (successful_requests / total_requests) < min_endpoint_success_ratio, the experiment will be marked as failed.
      By default it is 0.7 (0 means all requests are allowed to fail, while 1 means no request should fail).

  enable_chain_of_thought:
    type: string
    default: "true"
    description: Enable chain of thought for data generation

outputs:
  generated_train_file_path:
    type: uri_file
    description: Generated train data
    mode: rw_mount
  generated_validation_file_path:
    type: uri_file
    description: Generated validation data
    mode: rw_mount

code: src/
command: >-
  python generate_data.py
  --train_file_path ${{inputs.train_file_path}}
  $[[--validation_file_path ${{inputs.validation_file_path}}]]
  $[[--teacher_model_endpoint_name ${{inputs.teacher_model_endpoint_name}}]]
  $[[--teacher_model_endpoint_url ${{inputs.teacher_model_endpoint_url}}]]
  $[[--teacher_model_endpoint_key ${{inputs.teacher_model_endpoint_key}}]]
  --teacher_model_max_new_tokens ${{inputs.teacher_model_max_new_tokens}}
  --teacher_model_temperature ${{inputs.teacher_model_temperature}}
  --teacher_model_top_p ${{inputs.teacher_model_top_p}}
  --teacher_model_frequency_penalty ${{inputs.teacher_model_frequency_penalty}}
  --teacher_model_presence_penalty ${{inputs.teacher_model_presence_penalty}}
  $[[--teacher_model_stop ${{inputs.teacher_model_stop}}]]
  --request_batch_size ${{inputs.request_batch_size}}
  --min_endpoint_success_ratio ${{inputs.min_endpoint_success_ratio}}
  --enable_chain_of_thought ${{inputs.enable_chain_of_thought}}
  --generated_train_file_path ${{outputs.generated_train_file_path}}
  --generated_validation_file_path ${{outputs.generated_validation_file_path}}
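For orientation, the snippet below is a minimal, hypothetical sketch of how this component could be wired into a pipeline with the Azure ML Python SDK v2; the data asset, endpoint name, and workspace config are placeholders, not values from this commit. Note also the gate declared above: with the default min_endpoint_success_ratio of 0.7, a run in which only 65 of 100 teacher requests succeed (0.65 < 0.7) would be marked failed.

# Hypothetical pipeline wiring for the component above (Azure ML SDK v2).
# The data asset, endpoint name, and workspace config are illustrative only.
from azure.ai.ml import Input, MLClient, dsl, load_component
from azure.identity import DefaultAzureCredential

ml_client = MLClient.from_config(credential=DefaultAzureCredential())

# Load the component from its local spec; once registered it could be
# fetched from the azureml registry instead.
generate_data = load_component(
    source="assets/training/distillation/components/data_generation/spec.yaml"
)


@dsl.pipeline(description="OSS distillation: teacher data generation")
def distillation_data_pipeline(train_path: Input):
    step = generate_data(
        train_file_path=train_path,
        teacher_model_endpoint_name="my-teacher-endpoint",  # hypothetical endpoint
        enable_chain_of_thought="true",
    )
    return {"generated_train": step.outputs.generated_train_file_path}


pipeline_job = distillation_data_pipeline(
    train_path=Input(type="uri_file", path="azureml:my_train_data:1")  # hypothetical asset
)
ml_client.jobs.create_or_update(pipeline_job, experiment_name="oss-distillation")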
assets/training/distillation/components/data_generation/src/common/__init__.py (4 additions, 0 deletions)
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Data generator common module init."""
assets/training/distillation/components/data_generation/src/common/constants.py (54 additions, 0 deletions)
@@ -0,0 +1,54 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Data generator constants."""


# COMPONENT META
COMPONENT_NAME = "oss_distillation_generate_data"

# DATA GENERATOR VALIDATION
SUPPORTED_FILE_FORMATS = [".jsonl"]
TRAIN_FILE_NAME = "train_input.jsonl"
VALIDATION_FILE_NAME = "validation_input.jsonl"

# Scoring paths
VLLM_CHAT_SCORE_PATH = "/v1/chat/completions"
HFTV2_TEXT_GEN_SCORE_PATH = "/score"

# DATA GEN REQUEST
DEFAULT_SUCCESS_RATIO = 0.7
DEFAULT_REQUEST_BATCH_SIZE = 10
MAX_BATCH_SIZE = 100

# VLLM INFERENCE KEYS
TOP_P = "top_p"
MAX_TOKENS = "max_tokens"
MAX_NEW_TOKENS = "max_new_tokens"
TEMPERATURE = "temperature"
FREQUENCY_PENALTY = "frequency_penalty"
PRESENCE_PENALTY = "presence_penalty"
STOP_TOKEN = "stop"

# TEACHER MODEL DEFAULT INFERENCE PARAMS
DEFAULT_MAX_NEW_TOKENS = 128
DEFAULT_TOP_P = 0.1
DEFAULT_TEMPERATURE = 0.2

# CHAIN OF THOUGHT (COT)
COT_SYSTEM_PROMPT = (
    "You are a helpful assistant. "
    "Write out in a step by step manner your reasoning about the answer using no more than 80 words. "
    "Based on the reasoning, produce the final answer. "
    "Your response should be in JSON format without using any backticks. "
    "The JSON is a dictionary whose keys are 'reason' and 'answer_choice'."
)


class InferenceMode:
    """Supported inference modes."""

    HFTV2_CHAT_COMPLETION = "hftv2_chat_completion"
    HFTV2_TEXT_GENERATION = "hftv2_text_generation"
    VLLM_CHAT_COMPLETION = "vllm_chat_completion"
    VLLM_TEXT_GENERATION = "vllm_text_generation"
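These constants are consumed by generate_data.py, which is not part of this excerpt; the sketch below is only an illustration of how they might be combined into a single vLLM chat-completion request. vLLM's OpenAI-style route expects max_tokens, so the component's max_new_tokens value is assumed to map onto that key here; the requests dependency and the build_cot_payload/score_one helper names are likewise assumptions.

# Illustrative only: shows how the constants above could form one request to
# a teacher endpoint exposing the vLLM chat-completions route.
import requests  # assumed dependency; the real request code lives in generate_data.py

from common.constants import (
    COT_SYSTEM_PROMPT,
    DEFAULT_MAX_NEW_TOKENS,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_P,
    MAX_TOKENS,
    TEMPERATURE,
    TOP_P,
    VLLM_CHAT_SCORE_PATH,
)


def build_cot_payload(question: str) -> dict:
    """Assemble a chat-completion payload that asks for CoT-style JSON output."""
    return {
        "messages": [
            {"role": "system", "content": COT_SYSTEM_PROMPT},
            {"role": "user", "content": question},
        ],
        # vLLM's OpenAI-compatible API takes `max_tokens`; the component's
        # max_new_tokens input is assumed to map onto it.
        MAX_TOKENS: DEFAULT_MAX_NEW_TOKENS,
        TEMPERATURE: DEFAULT_TEMPERATURE,
        TOP_P: DEFAULT_TOP_P,
    }


def score_one(base_url: str, api_key: str, question: str) -> dict:
    """POST a single question to <base_url>/v1/chat/completions and return the JSON body."""
    url = base_url.rstrip("/") + VLLM_CHAT_SCORE_PATH
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    response = requests.post(url, headers=headers, json=build_cot_payload(question), timeout=180)
    response.raise_for_status()
    return response.json()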
assets/training/distillation/components/data_generation/src/common/utils.py (148 additions, 0 deletions)
@@ -0,0 +1,148 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Data generator utils."""

import os
import time

from azure.ai.ml import MLClient
from azure.ai.ml.identity import AzureMLOnBehalfOfCredential
from azure.identity import AzureCliCredential, ManagedIdentityCredential
from azureml.acft.common_components import get_logger_app
from azureml.core import Run, Workspace
from azureml.core.run import _OfflineRun

from typing import Union


logger = get_logger_app("azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import")
RETRY_DELAY = 5


def retry(times: int):
    """Retry utility to wrap.

    Args:
        times (int): Number of retries
    """

    def decorator(func):
        def newfn(*args, **kwargs):
            attempt = 1
            while attempt <= times:
                try:
                    return func(*args, **kwargs)
                except Exception:
                    attempt += 1
                    ex_msg = "Exception thrown when attempting to run {}, attempt {} of {}".format(
                        func.__name__, attempt, times
                    )
                    logger.warning(ex_msg)
                    if attempt < times:
                        time.sleep(RETRY_DELAY)
                    else:
                        logger.warning(
                            "Retried {} times when calling {}, now giving up!".format(times, func.__name__)
                        )
                        raise

        return newfn

    return decorator


def get_credential() -> Union[ManagedIdentityCredential, AzureMLOnBehalfOfCredential]:
    """Create and validate credentials."""
    # try msi, followed by obo, followed by azure cli
    credential = None
    try:
        msi_client_id = os.environ.get("DEFAULT_IDENTITY_CLIENT_ID")
        credential = ManagedIdentityCredential(client_id=msi_client_id)
        credential.get_token("https://management.azure.com/.default")
        logger.info("Using MSI creds")
        return credential
    except Exception:
        logger.error("MSI auth failed")
    try:
        credential = AzureMLOnBehalfOfCredential()
        credential.get_token("https://management.azure.com/.default")
        logger.info("Using OBO creds")
        return credential
    except Exception:
        logger.error("OBO cred failed")
    try:
        credential = AzureCliCredential()
        credential.get_token("https://management.azure.com/.default")
        logger.info("Using Azure CLI creds")
        return credential
    except Exception:
        logger.error("Azure CLI cred failed")

    raise Exception("Error creating credentials.")


def get_workspace() -> Workspace:
    """Return current workspace."""
    run = Run.get_context()
    if isinstance(run, _OfflineRun):
        ws: Workspace = Workspace.from_config("config.json")
    else:
        ws: Workspace = run.experiment.workspace
    return ws


def get_workspace_mlclient(workspace: Workspace = None) -> MLClient:
    """Return workspace mlclient."""
    credential = get_credential()
    workspace = get_workspace() if workspace is None else workspace
    if credential and workspace:
        return MLClient(
            credential,
            subscription_id=workspace.subscription_id,
            resource_group_name=workspace.resource_group,
            workspace_name=workspace.name
        )
    raise Exception("Error creating MLClient. No credentials or workspace found")


def get_online_endpoint_key(mlclient_ws: MLClient, endpoint_name: str) -> str:
    """Return online endpoint primary key."""
    try:
        keys = mlclient_ws.online_endpoints.get_keys(endpoint_name)
        return keys.primary_key
    except Exception as e:
        logger.error(f"Exception in fetching online endpoint keys for endpoint name: {endpoint_name}. Error {e}")
        return None


def get_online_endpoint_url(mlclient_ws: MLClient, endpoint_name: str) -> str:
    """Return online endpoint URL for an endpoint name."""
    try:
        endpoint = mlclient_ws.online_endpoints.get(endpoint_name)
        return endpoint.scoring_uri
    except Exception as e:
        logger.error(
            f"Exception in fetching online endpoint scoring URL for endpoint name: {endpoint_name}. Error {e}")
        return None


def get_serverless_endpoint_key(mlclient_ws: MLClient, endpoint_name: str) -> str:
    """Return serverless endpoint primary key."""
    try:
        keys = mlclient_ws.serverless_endpoints.get_keys(endpoint_name)
        return keys.primary_key
    except Exception as e:
        logger.error(f"Exception in fetching serverless endpoint keys for endpoint name: {endpoint_name}. Error {e}")
        return None


def get_serverless_endpoint_url(mlclient_ws: MLClient, endpoint_name: str) -> str:
    """Return serverless endpoint URL for an endpoint name."""
    try:
        endpoint = mlclient_ws.serverless_endpoints.get(endpoint_name)
        return endpoint.scoring_uri
    except Exception as e:
        logger.error(
            f"Exception in fetching serverless endpoint scoring URL for endpoint name: {endpoint_name}. Error {e}")
        return None
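generate_data.py (not included in this excerpt) is the consumer of these helpers; the sketch below is an assumed composition showing how an endpoint name supplied to the component could be resolved to a scoring URL and key, trying online endpoints first and falling back to serverless ones. The resolve_endpoint name and the fallback order are assumptions, not code from this commit. The retry decorator defined above would typically wrap the per-batch scoring call, e.g. @retry(times=3) on the function that posts requests to the teacher endpoint.

# Assumed usage of the helpers in utils.py; not code from this commit.
from common.utils import (
    get_online_endpoint_key,
    get_online_endpoint_url,
    get_serverless_endpoint_key,
    get_serverless_endpoint_url,
    get_workspace_mlclient,
)


def resolve_endpoint(endpoint_name: str) -> tuple:
    """Resolve (scoring_url, api_key) for a named endpoint: online first, then serverless."""
    ml_client = get_workspace_mlclient()

    url = get_online_endpoint_url(ml_client, endpoint_name)
    key = get_online_endpoint_key(ml_client, endpoint_name)
    if url and key:
        return url, key

    # Fall back to serverless (MaaS-style) deployments of the teacher model.
    url = get_serverless_endpoint_url(ml_client, endpoint_name)
    key = get_serverless_endpoint_key(ml_client, endpoint_name)
    if url and key:
        return url, key

    raise ValueError(f"Could not resolve scoring URL and key for endpoint '{endpoint_name}'.")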