Add distillation component (#3188)
* Add distillation component

* throw exception, update copyrights

* fix flakes

* fix flakes

* update
novaturient95 authored Jul 23, 2024
1 parent d317ef1 commit e277f58
Showing 8 changed files with 1,157 additions and 0 deletions.
@@ -0,0 +1,3 @@
type: component
spec: spec.yaml
categories: ["Foundational Models", "Finetune", "Distillation"]
117 changes: 117 additions & 0 deletions assets/training/distillation/components/data_generation/spec.yaml
@@ -0,0 +1,117 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
name: oss_distillation_generate_data
version: 0.0.1
type: command

is_deterministic: True

display_name: OSS Distillation Generate Data
description: Component to generate data from a teacher model endpoint

environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/63

inputs:
  # Inputs
  train_file_path:
    type: uri_file
    description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
    mode: rw_mount

  validation_file_path:
    type: uri_file
    optional: true
    description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
    mode: rw_mount

  teacher_model_endpoint_name:
    type: string
    optional: true
    description: Teacher model endpoint name

  teacher_model_endpoint_url:
    type: string
    optional: true
    description: Teacher model endpoint URL

  teacher_model_endpoint_key:
    type: string
    optional: true
    description: Teacher model endpoint key

  teacher_model_max_new_tokens:
    type: integer
    default: 128
    description: Teacher model max_new_tokens inference parameter

  teacher_model_temperature:
    type: number
    default: 0.2
    description: Teacher model temperature inference parameter

  teacher_model_top_p:
    type: number
    default: 0.1
    description: Teacher model top_p inference parameter

  teacher_model_frequency_penalty:
    type: number
    default: 0.0
    description: Teacher model frequency penalty inference parameter

  teacher_model_presence_penalty:
    type: number
    default: 0.0
    description: Teacher model presence penalty inference parameter

  teacher_model_stop:
    type: string
    optional: true
    description: Teacher model stop inference parameter

  request_batch_size:
    type: integer
    default: 10
    description: Number of data records sent to the teacher model endpoint in one batch

  min_endpoint_success_ratio:
    type: number
    default: 0.7
    description: >
      The minimum value of (successful_requests / total_requests) required for classifying inference as successful.
      If (successful_requests / total_requests) < min_endpoint_success_ratio, the experiment will be marked as failed.
      The default is 0.7 (0 means all requests are allowed to fail, while 1 means no request may fail).

  enable_chain_of_thought:
    type: string
    default: "true"
    description: Enable chain of thought for data generation
outputs:
  generated_train_file_path:
    type: uri_file
    description: Generated train data
    mode: rw_mount
  generated_validation_file_path:
    type: uri_file
    description: Generated validation data
    mode: rw_mount

code: src/
command: >-
  python generate_data.py
  --train_file_path ${{inputs.train_file_path}}
  $[[--validation_file_path ${{inputs.validation_file_path}}]]
  $[[--teacher_model_endpoint_name ${{inputs.teacher_model_endpoint_name}}]]
  $[[--teacher_model_endpoint_url ${{inputs.teacher_model_endpoint_url}}]]
  $[[--teacher_model_endpoint_key ${{inputs.teacher_model_endpoint_key}}]]
  --teacher_model_max_new_tokens ${{inputs.teacher_model_max_new_tokens}}
  --teacher_model_temperature ${{inputs.teacher_model_temperature}}
  --teacher_model_top_p ${{inputs.teacher_model_top_p}}
  --teacher_model_frequency_penalty ${{inputs.teacher_model_frequency_penalty}}
  --teacher_model_presence_penalty ${{inputs.teacher_model_presence_penalty}}
  $[[--teacher_model_stop ${{inputs.teacher_model_stop}}]]
  --request_batch_size ${{inputs.request_batch_size}}
  --min_endpoint_success_ratio ${{inputs.min_endpoint_success_ratio}}
  --enable_chain_of_thought ${{inputs.enable_chain_of_thought}}
  --generated_train_file_path ${{outputs.generated_train_file_path}}
  --generated_validation_file_path ${{outputs.generated_validation_file_path}}
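
For context, a command component like this is typically consumed from a pipeline job. The sketch below is a hypothetical example using the `azure.ai.ml` SDK, not part of this commit: the compute name, endpoint name, and data path are placeholders, and the component is assumed to be loadable from this `spec.yaml`.

```python
# Hypothetical usage sketch: load the component from spec.yaml and run it
# in a pipeline. Compute, endpoint, and data paths are placeholders.
from azure.ai.ml import Input, MLClient, load_component
from azure.ai.ml.dsl import pipeline
from azure.identity import DefaultAzureCredential

ml_client = MLClient.from_config(credential=DefaultAzureCredential())

generate_data = load_component(source="spec.yaml")

@pipeline(default_compute="gpu-cluster")  # placeholder compute name
def distillation_data_gen(train_data):
    gen = generate_data(
        train_file_path=train_data,
        teacher_model_endpoint_name="my-teacher-endpoint",  # placeholder
        teacher_model_temperature=0.2,
        enable_chain_of_thought="true",
    )
    return {"generated_train": gen.outputs.generated_train_file_path}

job = distillation_data_gen(
    train_data=Input(
        type="uri_file",
        path="azureml://datastores/workspaceblobstore/paths/train.jsonl",  # placeholder
    )
)
ml_client.jobs.create_or_update(job)
```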
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Data generator common module init."""
@@ -0,0 +1,54 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Data generatior constants."""


# COMPONENT META
COMPONENT_NAME = "oss_distillation_generate_data"

# DATA GENERATOR VALIDATION
SUPPORTED_FILE_FORMATS = [".jsonl"]
TRAIN_FILE_NAME = "train_input.jsonl"
VALIDATION_FILE_NAME = "validation_input.jsonl"

# Scoring paths
VLLM_CHAT_SCORE_PATH = "/v1/chat/completions"
HFTV2_TEXT_GEN_SCORE_PATH = "/score"

# DATA GEN REQUEST
DEFAULT_SUCCESS_RATIO = 0.7
DEFAULT_REQUEST_BATCH_SIZE = 10
MAX_BATCH_SIZE = 100

# VLLM INFERENCE KEYS
TOP_P = "top_p"
MAX_TOKENS = "max_tokens"
MAX_NEW_TOKENS = "max_new_tokens"
TEMPERATURE = "temperature"
FREQUENCY_PENALTY = "frequency_penalty"
PRESENCE_PENALTY = "presence_penalty"
STOP_TOKEN = "stop"

# TEACHER MODEL DEFAULT INFERENCE PARAMS
DEFAULT_MAX_NEW_TOKENS = 128
DEFAULT_TOP_P = 0.1
DEFAULT_TEMPERATURE = 0.2

# CHAIN OF THOUGHT (COT)
COT_SYSTEM_PROMPT = (
    "You are a helpful assistant. "
    "Write out in a step by step manner your reasoning about the answer using no more than 80 words. "
    "Based on the reasoning, produce the final answer. "
    "Your response should be in JSON format without using any backticks. "
    "The JSON is a dictionary whose keys are 'reason' and 'answer_choice'."
)


class InferenceMode:
    """Supported inference modes."""

    HFTV2_CHAT_COMPLETION = "hftv2_chat_completion"
    HFTV2_TEXT_GENERATION = "hftv2_text_generation"
    VLLM_CHAT_COMPLETION = "vllm_chat_completion"
    VLLM_TEXT_GENERATION = "vllm_text_generation"
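
To illustrate how these constants fit together, here is a hedged sketch of composing a vLLM chat-completions request against a scoring endpoint. The endpoint URL, key, sample question, and `requests`-based call are illustrative assumptions; the component's actual request code lives in `generate_data.py`, which is not shown here.

```python
# Illustrative sketch (not part of this module): build a vLLM
# chat-completions request from the constants above. The URL and key are
# placeholders, and the constants are assumed importable as `constants`.
import requests

from constants import (
    COT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS, DEFAULT_TEMPERATURE,
    DEFAULT_TOP_P, MAX_TOKENS, TEMPERATURE, TOP_P, VLLM_CHAT_SCORE_PATH,
)

endpoint_url = "https://my-endpoint.westus.inference.ml.azure.com"  # placeholder
api_key = "<endpoint-key>"  # placeholder

payload = {
    "messages": [
        {"role": "system", "content": COT_SYSTEM_PROMPT},
        {"role": "user", "content": "Which gas do plants absorb? A) O2 B) CO2"},
    ],
    # Inference parameters, keyed by the vLLM inference keys above.
    MAX_TOKENS: DEFAULT_MAX_NEW_TOKENS,
    TEMPERATURE: DEFAULT_TEMPERATURE,
    TOP_P: DEFAULT_TOP_P,
}

response = requests.post(
    endpoint_url.rstrip("/") + VLLM_CHAT_SCORE_PATH,
    headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
    json=payload,
    timeout=60,
)
response.raise_for_status()
# With the CoT system prompt, the assistant message content is expected to be
# JSON with 'reason' and 'answer_choice' keys.
print(response.json())
```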
@@ -0,0 +1,148 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Data generator utils."""

import os
import time

from azure.ai.ml import MLClient
from azure.ai.ml.identity import AzureMLOnBehalfOfCredential
from azure.identity import AzureCliCredential, ManagedIdentityCredential
from azureml.acft.common_components import get_logger_app
from azureml.core import Run, Workspace
from azureml.core.run import _OfflineRun

from typing import Union


logger = get_logger_app("azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import")
RETRY_DELAY = 5


def retry(times: int):
    """Retry utility to wrap.

    Args:
        times (int): Number of retries
    """

    def decorator(func):
        def newfn(*args, **kwargs):
            for attempt in range(1, times + 1):
                try:
                    return func(*args, **kwargs)
                except Exception:
                    logger.warning(
                        "Exception thrown when attempting to run {}, attempt {} of {}".format(
                            func.__name__, attempt, times
                        )
                    )
                    if attempt < times:
                        time.sleep(RETRY_DELAY)
                    else:
                        logger.warning(
                            "Retried {} times when calling {}, now giving up!".format(times, func.__name__)
                        )
                        raise

        return newfn

    return decorator


def get_credential() -> Union[ManagedIdentityCredential, AzureMLOnBehalfOfCredential, AzureCliCredential]:
    """Create and validate credentials."""
    # Try MSI, followed by OBO, followed by Azure CLI.
    credential = None
    try:
        msi_client_id = os.environ.get("DEFAULT_IDENTITY_CLIENT_ID")
        credential = ManagedIdentityCredential(client_id=msi_client_id)
        credential.get_token("https://management.azure.com/.default")
        logger.info("Using MSI creds")
        return credential
    except Exception:
        logger.error("MSI auth failed")
    try:
        credential = AzureMLOnBehalfOfCredential()
        credential.get_token("https://management.azure.com/.default")
        logger.info("Using OBO creds")
        return credential
    except Exception:
        logger.error("OBO auth failed")
    try:
        credential = AzureCliCredential()
        credential.get_token("https://management.azure.com/.default")
        logger.info("Using Azure CLI creds")
        return credential
    except Exception:
        logger.error("Azure CLI auth failed")

    raise Exception("Error creating credentials.")


def get_workspace() -> Workspace:
    """Return current workspace."""
    run = Run.get_context()
    if isinstance(run, _OfflineRun):
        ws: Workspace = Workspace.from_config("config.json")
    else:
        ws: Workspace = run.experiment.workspace
    return ws


def get_workspace_mlclient(workspace: Workspace = None) -> MLClient:
    """Return workspace mlclient."""
    credential = get_credential()
    workspace = get_workspace() if workspace is None else workspace
    if credential and workspace:
        return MLClient(
            credential,
            subscription_id=workspace.subscription_id,
            resource_group_name=workspace.resource_group,
            workspace_name=workspace.name
        )
    raise Exception("Error creating MLClient. No credentials or workspace found")


def get_online_endpoint_key(mlclient_ws: MLClient, endpoint_name: str) -> str:
    """Return online endpoint primary key."""
    try:
        keys = mlclient_ws.online_endpoints.get_keys(endpoint_name)
        return keys.primary_key
    except Exception as e:
        logger.error(f"Exception in fetching online endpoint keys for endpoint name: {endpoint_name}. Error {e}")
        return None


def get_online_endpoint_url(mlclient_ws: MLClient, endpoint_name: str) -> str:
    """Return online endpoint URL for an endpoint name."""
    try:
        endpoint = mlclient_ws.online_endpoints.get(endpoint_name)
        return endpoint.scoring_uri
    except Exception as e:
        logger.error(
            f"Exception in fetching online endpoint scoring URL for endpoint name: {endpoint_name}. Error {e}")
        return None


def get_serverless_endpoint_key(mlclient_ws: MLClient, endpoint_name: str) -> str:
    """Return serverless endpoint primary key."""
    try:
        keys = mlclient_ws.serverless_endpoints.get_keys(endpoint_name)
        return keys.primary_key
    except Exception as e:
        logger.error(f"Exception in fetching serverless endpoint keys for endpoint name: {endpoint_name}. Error {e}")
        return None


def get_serverless_endpoint_url(mlclient_ws: MLClient, endpoint_name: str) -> str:
    """Return serverless endpoint URL for an endpoint name."""
    try:
        endpoint = mlclient_ws.serverless_endpoints.get(endpoint_name)
        return endpoint.scoring_uri
    except Exception as e:
        logger.error(
            f"Exception in fetching serverless endpoint scoring URL for endpoint name: {endpoint_name}. Error {e}")
        return None
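
Taken together, these helpers resolve an endpoint's scoring URL and key from the current workspace. The following is a minimal sketch of how they might be combined, assuming the helpers above are in scope; the endpoint name is a placeholder, and `@retry` guards the transient lookups.

```python
# Illustrative sketch: resolve scoring URL and key for a named online
# endpoint using the helpers above. The endpoint name is a placeholder.
@retry(times=3)
def resolve_endpoint(endpoint_name: str):
    mlclient = get_workspace_mlclient()
    url = get_online_endpoint_url(mlclient, endpoint_name)
    key = get_online_endpoint_key(mlclient, endpoint_name)
    if not url or not key:
        raise Exception(f"Could not resolve endpoint details for {endpoint_name}")
    return url, key

scoring_url, scoring_key = resolve_endpoint("my-teacher-endpoint")  # placeholder
```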
