From e277f58f5912f19d111673fba04de6cc4ad69bb8 Mon Sep 17 00:00:00 2001
From: Ayush Mishra <61145377+novaturient95@users.noreply.github.com>
Date: Tue, 23 Jul 2024 20:16:43 +0530
Subject: [PATCH] Add distillation component (#3188)

* Add distillation component

* throw exception, update copyrights

* fix flakes

* fix flakes

* update
---
 .../components/data_generation/asset.yaml     |   3 +
 .../components/data_generation/spec.yaml      | 117 ++++
 .../data_generation/src/common/__init__.py    |   4 +
 .../data_generation/src/common/constants.py   |  54 ++
 .../data_generation/src/common/utils.py       | 148 +++++
 .../data_generation/src/generate_data.py      | 517 ++++++++++++++++++
 .../components/pipeline/asset.yaml            |   3 +
 .../components/pipeline/spec.yaml             | 311 +++++++++++
 8 files changed, 1157 insertions(+)
 create mode 100644 assets/training/distillation/components/data_generation/asset.yaml
 create mode 100644 assets/training/distillation/components/data_generation/spec.yaml
 create mode 100644 assets/training/distillation/components/data_generation/src/common/__init__.py
 create mode 100644 assets/training/distillation/components/data_generation/src/common/constants.py
 create mode 100644 assets/training/distillation/components/data_generation/src/common/utils.py
 create mode 100644 assets/training/distillation/components/data_generation/src/generate_data.py
 create mode 100644 assets/training/distillation/components/pipeline/asset.yaml
 create mode 100644 assets/training/distillation/components/pipeline/spec.yaml

diff --git a/assets/training/distillation/components/data_generation/asset.yaml b/assets/training/distillation/components/data_generation/asset.yaml
new file mode 100644
index 0000000000..d0767a9360
--- /dev/null
+++ b/assets/training/distillation/components/data_generation/asset.yaml
@@ -0,0 +1,3 @@
+type: component
+spec: spec.yaml
+categories: ["Foundational Models", "Finetune", "Distillation"]
diff --git a/assets/training/distillation/components/data_generation/spec.yaml b/assets/training/distillation/components/data_generation/spec.yaml
new file mode 100644
index 0000000000..dd5e49a7a3
--- /dev/null
+++ b/assets/training/distillation/components/data_generation/spec.yaml
@@ -0,0 +1,117 @@
+$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
+name: oss_distillation_generate_data
+version: 0.0.1
+type: command
+
+is_deterministic: True
+
+display_name: OSS Distillation Generate Data
+description: Component to generate data from teacher model endpoint
+
+environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/63
+
+inputs:
+  # Inputs
+  train_file_path:
+    type: uri_file
+    description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
+    mode: rw_mount
+
+  validation_file_path:
+    type: uri_file
+    optional: true
+    description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
+    mode: rw_mount
+
+  teacher_model_endpoint_name:
+    type: string
+    optional: true
+    description: Teacher model endpoint name
+
+  teacher_model_endpoint_url:
+    type: string
+    optional: true
+    description: Teacher model endpoint URL
+
+  teacher_model_endpoint_key:
+    type: string
+    optional: true
+    description: Teacher model endpoint key
+
+  teacher_model_max_new_tokens:
+    type: integer
+    default: 128
+    description: Teacher model max_new_tokens inference parameter
+
+  teacher_model_temperature:
+    type: number
+    default: 0.2
+    description: Teacher model temperature inference parameter
+
+  teacher_model_top_p:
+    type: number
+    default: 0.1
+    description: Teacher model top_p inference parameter
+
+  teacher_model_frequency_penalty:
+    type: number
+    default: 0.0
+    description: Teacher model frequency penalty inference parameter
+
+  teacher_model_presence_penalty:
+    type: number
+    default: 0.0
+    description: Teacher model presence penalty inference parameter
+
+  teacher_model_stop:
+    type: string
+    optional: true
+    description: Teacher model stop inference parameter
+
+  request_batch_size:
+    type: integer
+    default: 10
+    description: Number of data records to send to the teacher model endpoint in one request
+
+  min_endpoint_success_ratio:
+    type: number
+    default: 0.7
+    description: >
+      The minimum value of (successful_requests / total_requests) required for classifying inference as successful.
+      If (successful_requests / total_requests) < min_endpoint_success_ratio, the experiment will be marked as failed.
+      By default it is 0.7 (0 means all requests are allowed to fail while 1 means no request should fail.)
+
+  enable_chain_of_thought:
+    type: string
+    default: "true"
+    description: Enable chain of thought for data generation
+
+outputs:
+  generated_train_file_path:
+    type: uri_file
+    description: Generated train data
+    mode: rw_mount
+  generated_validation_file_path:
+    type: uri_file
+    description: Generated validation data
+    mode: rw_mount
+
+code: src/
+command: >-
+  python generate_data.py
+  --train_file_path ${{inputs.train_file_path}}
+  $[[--validation_file_path ${{inputs.validation_file_path}}]]
+  $[[--teacher_model_endpoint_name ${{inputs.teacher_model_endpoint_name}}]]
+  $[[--teacher_model_endpoint_url ${{inputs.teacher_model_endpoint_url}}]]
+  $[[--teacher_model_endpoint_key ${{inputs.teacher_model_endpoint_key}}]]
+  --teacher_model_max_new_tokens ${{inputs.teacher_model_max_new_tokens}}
+  --teacher_model_temperature ${{inputs.teacher_model_temperature}}
+  --teacher_model_top_p ${{inputs.teacher_model_top_p}}
+  --teacher_model_frequency_penalty ${{inputs.teacher_model_frequency_penalty}}
+  --teacher_model_presence_penalty ${{inputs.teacher_model_presence_penalty}}
+  $[[--teacher_model_stop ${{inputs.teacher_model_stop}}]]
+  --request_batch_size ${{inputs.request_batch_size}}
+  --min_endpoint_success_ratio ${{inputs.min_endpoint_success_ratio}}
+  --enable_chain_of_thought ${{inputs.enable_chain_of_thought}}
+  --generated_train_file_path ${{outputs.generated_train_file_path}}
+  --generated_validation_file_path ${{outputs.generated_validation_file_path}}
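
For reference, `generate_data.py` consumes `train_file_path` as JSONL where each row carries a single chat-style `messages` column. A minimal sketch of one input record; the question text and file name usage are illustrative, not taken from this PR:

```python
import json

# Illustrative only: each JSONL row holds one "messages" column containing an
# OpenAI-style chat list (system turn optional, user turn carries the question).
sample_row = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Which gas makes up most of Earth's atmosphere? A) Oxygen B) Nitrogen C) Argon"},
    ]
}

with open("train_input.jsonl", "w") as f:
    f.write(json.dumps(sample_row) + "\n")
```
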
diff --git a/assets/training/distillation/components/data_generation/src/common/__init__.py b/assets/training/distillation/components/data_generation/src/common/__init__.py
new file mode 100644
index 0000000000..fad5202cd2
--- /dev/null
+++ b/assets/training/distillation/components/data_generation/src/common/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Data generator common module init."""
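
The generated outputs mirror the inputs with the teacher's reply appended. A sketch of one row of `generated_train_file_path`, continuing the hypothetical record above and assuming the teacher answered "B":

```python
# Sketch only: on a successful teacher call, the component appends the answer
# as an assistant turn and writes the row back out as a JSONL line.
generated_row = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Which gas makes up most of Earth's atmosphere? A) Oxygen B) Nitrogen C) Argon"},
        {"role": "assistant", "content": "B"},
    ]
}
```
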
diff --git a/assets/training/distillation/components/data_generation/src/common/constants.py b/assets/training/distillation/components/data_generation/src/common/constants.py
new file mode 100644
index 0000000000..c0d7e6724d
--- /dev/null
+++ b/assets/training/distillation/components/data_generation/src/common/constants.py
@@ -0,0 +1,54 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Data generator constants."""
+
+
+# COMPONENT META
+COMPONENT_NAME = "oss_distillation_generate_data"
+
+# DATA GENERATOR VALIDATION
+SUPPORTED_FILE_FORMATS = [".jsonl"]
+TRAIN_FILE_NAME = "train_input.jsonl"
+VALIDATION_FILE_NAME = "validation_input.jsonl"
+
+# Scoring paths
+VLLM_CHAT_SCORE_PATH = "/v1/chat/completions"
+HFTV2_TEXT_GEN_SCORE_PATH = "/score"
+
+# DATA GEN REQUEST
+DEFAULT_SUCCESS_RATIO = 0.7
+DEFAULT_REQUEST_BATCH_SIZE = 10
+MAX_BATCH_SIZE = 100
+
+# VLLM INFERENCE KEYS
+TOP_P = "top_p"
+MAX_TOKENS = "max_tokens"
+MAX_NEW_TOKENS = "max_new_tokens"
+TEMPERATURE = "temperature"
+FREQUENCY_PENALTY = "frequency_penalty"
+PRESENCE_PENALTY = "presence_penalty"
+STOP_TOKEN = "stop"
+
+# TEACHER MODEL DEFAULT INFERENCE PARAMS
+DEFAULT_MAX_NEW_TOKENS = 128
+DEFAULT_TOP_P = 0.1
+DEFAULT_TEMPERATURE = 0.2
+
+# CHAIN OF THOUGHT (COT)
+COT_SYSTEM_PROMPT = (
+    "You are a helpful assistant. "
+    "Write out in a step by step manner your reasoning about the answer using no more than 80 words. "
+    "Based on the reasoning, produce the final answer. "
+    "Your response should be in JSON format without using any backticks. "
+    "The JSON is a dictionary whose keys are 'reason' and 'answer_choice'."
+)
+
+
+class InferenceMode:
+    """Supported inference modes."""
+
+    HFTV2_CHAT_COMPLETION = "hftv2_chat_completion"
+    HFTV2_TEXT_GENERATION = "hftv2_text_generation"
+    VLLM_CHAT_COMPLETION = "vllm_chat_completion"
+    VLLM_TEXT_GENERATION = "vllm_text_generation"
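
When chain of thought is enabled, `COT_SYSTEM_PROMPT` asks the teacher for a JSON object with `reason` and `answer_choice` keys. A sketch of how `generate_data.py` reduces such a response to the final answer; the response content below is hypothetical:

```python
import json

# Hypothetical teacher response content under COT_SYSTEM_PROMPT: a JSON
# dictionary whose keys are 'reason' and 'answer_choice'.
content = '{"reason": "Nitrogen makes up about 78% of air, more than oxygen or argon.", "answer_choice": "B"}'

# Only the final answer is kept; rows whose content fails to parse as JSON
# raise in process_request and are dropped from the generated dataset.
answer = json.loads(content)["answer_choice"]
assert answer == "B"
```
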
diff --git a/assets/training/distillation/components/data_generation/src/common/utils.py b/assets/training/distillation/components/data_generation/src/common/utils.py
new file mode 100644
index 0000000000..3af4a15891
--- /dev/null
+++ b/assets/training/distillation/components/data_generation/src/common/utils.py
@@ -0,0 +1,148 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Data generator utils."""
+
+import os
+import time
+
+from azure.ai.ml import MLClient
+from azure.ai.ml.identity import AzureMLOnBehalfOfCredential
+from azure.identity import AzureCliCredential, ManagedIdentityCredential
+from azureml.acft.common_components import get_logger_app
+from azureml.core import Run, Workspace
+from azureml.core.run import _OfflineRun
+
+from typing import Union
+
+
+logger = get_logger_app("azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import")
+RETRY_DELAY = 5
+
+
+def retry(times: int):
+    """Retry decorator.
+
+    Args:
+        times (int): Number of attempts before giving up
+    """
+
+    def decorator(func):
+        def newfn(*args, **kwargs):
+            attempt = 1
+            while attempt <= times:
+                try:
+                    return func(*args, **kwargs)
+                except Exception:
+                    ex_msg = "Exception thrown when attempting to run {}, attempt {} of {}".format(
+                        func.__name__, attempt, times
+                    )
+                    logger.warning(ex_msg)
+                    if attempt < times:
+                        attempt += 1
+                        time.sleep(RETRY_DELAY)
+                    else:
+                        logger.warning(
+                            "Retried {} times when calling {}, now giving up!".format(times, func.__name__)
+                        )
+                        raise
+
+        return newfn
+
+    return decorator
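
A usage sketch for the decorator; `flaky_call` and its URL are hypothetical. The wrapped function is attempted up to three times, sleeping `RETRY_DELAY` seconds between attempts, and the last exception is re-raised once the retries are exhausted:

```python
import requests


# Hypothetical caller: transient HTTP failures are retried up to 3 times.
@retry(3)
def flaky_call():
    return requests.get("https://example.invalid/health", timeout=5)
```
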
+
+
+def get_credential() -> Union[ManagedIdentityCredential, AzureMLOnBehalfOfCredential]:
+    """Create and validate credentials."""
+    # try msi, followed by obo, followed by azure cli
+    credential = None
+    try:
+        msi_client_id = os.environ.get("DEFAULT_IDENTITY_CLIENT_ID")
+        credential = ManagedIdentityCredential(client_id=msi_client_id)
+        credential.get_token("https://management.azure.com/.default")
+        logger.info("Using MSI creds")
+        return credential
+    except Exception:
+        logger.error("MSI auth failed")
+    try:
+        credential = AzureMLOnBehalfOfCredential()
+        credential.get_token("https://management.azure.com/.default")
+        logger.info("Using OBO creds")
+        return credential
+    except Exception:
+        logger.error("OBO cred failed")
+    try:
+        credential = AzureCliCredential()
+        credential.get_token("https://management.azure.com/.default")
+        logger.info("Using Azure CLI creds")
+        return credential
+    except Exception:
+        logger.error("Azure CLI cred failed")
+
+    raise Exception("Error creating credentials.")
+
+
+def get_workspace() -> Workspace:
+    """Return current workspace."""
+    run = Run.get_context()
+    if isinstance(run, _OfflineRun):
+        ws: Workspace = Workspace.from_config("config.json")
+    else:
+        ws: Workspace = run.experiment.workspace
+    return ws
+
+
+def get_workspace_mlclient(workspace: Workspace = None) -> MLClient:
+    """Return workspace mlclient."""
+    credential = get_credential()
+    workspace = get_workspace() if workspace is None else workspace
+    if credential and workspace:
+        return MLClient(
+            credential,
+            subscription_id=workspace.subscription_id,
+            resource_group_name=workspace.resource_group,
+            workspace_name=workspace.name
+        )
+    raise Exception("Error creating MLClient. No credentials or workspace found")
+
+
+def get_online_endpoint_key(mlclient_ws: MLClient, endpoint_name: str) -> str:
+    """Return online endpoint primary key."""
+    try:
+        keys = mlclient_ws.online_endpoints.get_keys(endpoint_name)
+        return keys.primary_key
+    except Exception as e:
+        logger.error(f"Exception in fetching online endpoint keys for endpoint name: {endpoint_name}. Error {e}")
+        return None
+
+
+def get_online_endpoint_url(mlclient_ws: MLClient, endpoint_name: str) -> str:
+    """Return online endpoint URL for an endpoint name."""
+    try:
+        endpoint = mlclient_ws.online_endpoints.get(endpoint_name)
+        return endpoint.scoring_uri
+    except Exception as e:
+        logger.error(
+            f"Exception in fetching online endpoint scoring URL for endpoint name: {endpoint_name}. Error {e}")
+        return None
+
+
+def get_serverless_endpoint_key(mlclient_ws: MLClient, endpoint_name: str) -> str:
+    """Return serverless endpoint primary key."""
+    try:
+        keys = mlclient_ws.serverless_endpoints.get_keys(endpoint_name)
+        return keys.primary_key
+    except Exception as e:
+        logger.error(
+            f"Exception in fetching serverless endpoint keys for endpoint name: {endpoint_name}. Error {e}")
+        return None
+
+
+def get_serverless_endpoint_url(mlclient_ws: MLClient, endpoint_name: str) -> str:
+    """Return serverless endpoint URL for an endpoint name."""
+    try:
+        endpoint = mlclient_ws.serverless_endpoints.get(endpoint_name)
+        return endpoint.scoring_uri
+    except Exception as e:
+        logger.error(
+            f"Exception in fetching serverless endpoint scoring URL for endpoint name: {endpoint_name}. Error {e}")
+        return None
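
Taken together, these helpers let a caller resolve a scoring URL and key from just an endpoint name, trying serverless endpoints before online ones, which is the same fallback order `generate_data.py` uses. A sketch with a placeholder endpoint name:

```python
# "my-teacher-endpoint" is a placeholder; each helper returns None on failure,
# so `or` falls through from serverless to online endpoints.
mlclient_ws = get_workspace_mlclient()
name = "my-teacher-endpoint"
url = get_serverless_endpoint_url(mlclient_ws, name) or get_online_endpoint_url(mlclient_ws, name)
key = get_serverless_endpoint_key(mlclient_ws, name) or get_online_endpoint_key(mlclient_ws, name)
```
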
diff --git a/assets/training/distillation/components/data_generation/src/generate_data.py b/assets/training/distillation/components/data_generation/src/generate_data.py
new file mode 100644
index 0000000000..ace8c441cd
--- /dev/null
+++ b/assets/training/distillation/components/data_generation/src/generate_data.py
@@ -0,0 +1,517 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""File containing functions for the distillation data generation component."""
+
+import json
+import logging
+import pandas as pd
+
+import argparse
+import requests
+from argparse import Namespace
+from requests import Response
+from pathlib import Path
+from typing import List, Optional
+
+from azureml.acft.contrib.hf import VERSION, PROJECT_NAME
+from azureml.acft.contrib.hf.nlp.constants.constants import LOGS_TO_BE_FILTERED_IN_APPINSIGHTS
+from azureml.acft.common_components import get_logger_app, set_logging_parameters, LoggingLiterals
+from azureml.acft.common_components.utils.error_handling.exceptions import ACFTValidationException
+from azureml.acft.common_components.utils.error_handling.error_definitions import ACFTUserError
+from azureml.acft.common_components.utils.error_handling.swallow_all_exceptions_decorator import (
+    swallow_all_exceptions,
+)
+from azureml._common._error_definition.azureml_error import AzureMLError
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from common.constants import (
+    COMPONENT_NAME,
+    COT_SYSTEM_PROMPT,
+    DEFAULT_REQUEST_BATCH_SIZE,
+    DEFAULT_SUCCESS_RATIO,
+    DEFAULT_MAX_NEW_TOKENS,
+    DEFAULT_TEMPERATURE,
+    DEFAULT_TOP_P,
+    FREQUENCY_PENALTY,
+    PRESENCE_PENALTY,
+    MAX_NEW_TOKENS,
+    MAX_BATCH_SIZE,
+    TEMPERATURE,
+    TOP_P,
+    STOP_TOKEN,
+    SUPPORTED_FILE_FORMATS,
+    VLLM_CHAT_SCORE_PATH,
+)
+
+from common.utils import (
+    get_workspace_mlclient,
+    get_online_endpoint_key,
+    get_online_endpoint_url,
+    get_serverless_endpoint_key,
+    get_serverless_endpoint_url,
+    retry,
+)
+
+
+logger = get_logger_app("azureml.acft.contrib.hf.nlp.entry_point.data_import.data_import")
+
+
+def get_parser():
+    """
+    Add arguments and return the parser. Here we add all the arguments for all the tasks.
+
+    Those arguments that are not relevant for the input task should be ignored.
+    """
+    parser = argparse.ArgumentParser(description="Distillation data generation component", allow_abbrev=False)
+
+    # File I/O
+    parser.add_argument(
+        "--train_file_path",
+        type=str,
+        help="Input train file path",
+    )
+
+    parser.add_argument(
+        "--validation_file_path",
+        default=None,
+        type=str,
+        help="Input validation file path",
+    )
+
+    parser.add_argument(
+        "--generated_train_file_path",
+        type=Path,
+        default=None,
+        help="file to save the generated training data",
+    )
+
+    parser.add_argument(
+        "--generated_validation_file_path",
+        type=Path,
+        default=None,
+        help="file to save the generated validation data",
+    )
+
+    # add optional data-generator params
+    parser.add_argument(
+        "--teacher_model_endpoint_name",
+        type=str,
+        required=False,
+        help="Teacher model endpoint name",
+    )
+    parser.add_argument(
+        "--teacher_model_endpoint_url",
+        type=str,
+        required=False,
+        help="Teacher model endpoint URL",
+    )
+    parser.add_argument(
+        "--teacher_model_endpoint_key",
+        type=str,
+        required=False,
+        help="Teacher model endpoint key",
+    )
+    parser.add_argument(
+        "--teacher_model_max_new_tokens",
+        type=int,
+        required=False,
+        default=DEFAULT_MAX_NEW_TOKENS,
+        help="Teacher model max_new_tokens parameter",
+    )
+    parser.add_argument(
+        "--teacher_model_temperature",
+        type=float,
+        required=False,
+        default=DEFAULT_TEMPERATURE,
+        help="Teacher model temperature parameter",
+    )
+    parser.add_argument(
+        "--teacher_model_top_p",
+        type=float,
+        required=False,
+        default=DEFAULT_TOP_P,
+        help="Teacher model top-p parameter"
+    )
+    parser.add_argument(
+        "--teacher_model_frequency_penalty",
+        type=float,
+        required=False,
+        help="Teacher model frequency penalty parameter"
+    )
+    parser.add_argument(
+        "--teacher_model_presence_penalty",
+        type=float,
+        required=False,
+        help="Teacher model presence penalty parameter"
+    )
+    parser.add_argument(
+        "--teacher_model_stop",
+        type=str,
+        required=False,
+        help="Teacher model stop sequence parameter"
+    )
+    parser.add_argument(
+        "--request_batch_size",
+        type=int,
+        default=DEFAULT_REQUEST_BATCH_SIZE,
+        required=False,
+        help="Number of data records to process at a time.",
+    )
+    parser.add_argument(
+        "--min_endpoint_success_ratio",
+        type=float,
+        required=False,
+        default=DEFAULT_SUCCESS_RATIO,
+        help=(
+            "The minimum value of "
+            "(successful_requests / total_requests) required for classifying inference as successful. "
+            "If (successful_requests / total_requests) < min_endpoint_success_ratio, "
+            "the experiment will be marked as failed. "
+            f"By default it is {DEFAULT_SUCCESS_RATIO}. "
+            "(0 means all requests are allowed to fail while 1 means no request should fail.)"
+        )
+    )
+
+    parser.add_argument(
+        "--enable_chain_of_thought",
+        type=str,
+        required=False,
+        default="false",
+        help="This enables Chain of Thought"
+    )
+
+    return parser
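
A quick sanity check of the parser defaults, which mirror the component spec defaults (the path is a placeholder):

```python
# Only the required train path is supplied; everything else falls back to defaults.
args = get_parser().parse_args(["--train_file_path", "train_input.jsonl"])
assert args.teacher_model_max_new_tokens == 128
assert args.teacher_model_temperature == 0.2
assert args.teacher_model_top_p == 0.1
assert args.request_batch_size == 10
assert args.enable_chain_of_thought == "false"
```
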
+
+
+@retry(3)
+def _invoke_endpoint(url: str, key: str, data: dict) -> Response:
+    """Invoke endpoint with payload data.
+
+    Args:
+        url (str): Endpoint URL
+        key (str): Endpoint key
+        data (dict): Payload dictionary
+
+    Returns:
+        Response: Response from invocation
+    """
+    request_headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {key}"
+    }
+    response = requests.post(url, headers=request_headers, data=json.dumps(data))
+    return response
+
+
+def _validate_file_paths_with_supported_formats(file_paths: List[Optional[str]]):
+    """Check if the file path is in the list of supported formats."""
+    for file_path in file_paths:
+        if file_path:
+            file_suffix = Path(file_path).suffix.lower()
+            file_ext = file_suffix.split('?')[0]
+            if file_ext and file_ext not in SUPPORTED_FILE_FORMATS:
+                raise ACFTValidationException._with_error(
+                    AzureMLError.create(
+                        ACFTUserError,
+                        pii_safe_message=(
+                            f"{file_path} is not in list of supported file formats. "
+                            f"Supported file formats: {SUPPORTED_FILE_FORMATS}"
+                        )
+                    )
+                )
+
+
+def generate_synthetic_data(
+    teacher_model_endpoint_url: str,
+    teacher_model_endpoint_key: str,
+    inference_params: dict,
+    request_batch_size: int,
+    min_endpoint_success_ratio: float,
+    enable_cot: bool,
+    generated_train_file_path: Path,
+    generated_validation_file_path: Path,
+    train_file_path: Path,
+    validation_file_path: Path = None,
+):
+    """Generate and save synthetic data to the generated train and validation file paths.
+
+    Args:
+        teacher_model_endpoint_url (str): Teacher model endpoint URL
+        teacher_model_endpoint_key (str): Teacher model endpoint key
+        inference_params (dict): Inference params to hit endpoint with
+        request_batch_size (int): Input batch size for processing rows in train and validation dataset
+        min_endpoint_success_ratio (float): Minimum success ratio below which run will be considered a failure
+        enable_cot (bool): Enable Chain of Thought processing
+        generated_train_file_path / generated_validation_file_path (Path): Output paths for generated data
+        train_file_path (Path): Train JSONL file path
+        validation_file_path (Path, optional): Validation JSONL file path. Defaults to None.
+    """
+
+    def process_request(idx: str, enable_cot: bool, data: dict, url: str, endpoint_key: str) -> dict:
+        """Process a single request.
+
+        Args:
+            idx (str): Row index in Input data.
+            enable_cot (bool): If CoT is enabled
+            data (dict): Payload dict
+            url (str): Endpoint URL
+            endpoint_key (str): key to authenticate endpoint request
+
+        Returns:
+            dict: result dictionary
+        """
+        try:
+            logger.info(f"request_data: {repr(data)}")
+            response: Response = _invoke_endpoint(url=url, key=endpoint_key, data=data)
+            logger.info(f"response_text: {response.text}")
+            response_data = response.json()
+            logger.info(f"JSON response: {response_data}")
+
+            # use jsonpath or regex to capture prediction result
+            prediction_result = (
+                None if response.status_code != 200
+                # response content should be structured as below for a successful vllm response
+                else response_data['choices'][0]["message"]["content"].strip()
+            )
+
+            if enable_cot:
+                # Try loading JSON answer and filter 'answer_choice'
+                # if JSON loading fails, exception will be caught
+                # And this specific row would not be part of generated data
+                prediction_result = json.loads(prediction_result)['answer_choice']
+
+            return {
+                "idx": idx,
+                "status_code": response.status_code,
+                "text": prediction_result,
+                "exception": None,
+            }
+        except Exception as e:
+            logger.error(f"idx: {idx}. exception: {e}")
+            return {
+                "idx": idx,
+                "status_code": None,
+                "text": None,
+                "exception": e,
+            }
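
For context, a sketch of the request payload and the response shape that `process_request` expects from the vLLM chat route; all values are illustrative:

```python
# The payload is the row's messages merged with the inference params, posted
# to <endpoint_url>/v1/chat/completions (VLLM_CHAT_SCORE_PATH).
request_data = {
    "messages": [{"role": "user", "content": "Which gas makes up most of Earth's atmosphere?"}],
    "max_new_tokens": 128,
    "temperature": 0.2,
    "top_p": 0.1,
}

# A 200 response is expected to carry the answer at choices[0].message.content.
response_data = {"choices": [{"message": {"role": "assistant", "content": "Nitrogen"}}]}
prediction_result = response_data["choices"][0]["message"]["content"].strip()
```
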
+
+    def replace_cot_system_message(messages: List[dict]) -> List[dict]:
+        """Replace the system message with the CoT prompt, without changing the original messages list."""
+        cot_system_message = {'role': 'system', 'content': COT_SYSTEM_PROMPT}
+        return [(cot_system_message if message['role'] == 'system' else message) for message in messages]
+
+    def batch_process_data(input_file_path: Path, output_file_path: Path, batch_size: int) -> None:
+        """Batch process data and do a bulk request to teacher model endpoint.
+
+        Args:
+            input_file_path (Path): Input data file path
+            output_file_path (Path): Path to output directory
+            batch_size (int): Input batch size for processing rows in train and validation dataset
+
+        Raises:
+            Exception: if success ratio is less than min_endpoint_success_ratio
+        """
+        train_df = pd.read_json(input_file_path, lines=True, chunksize=batch_size)
+        total_rows = 0
+        error_count = 0
+        output_data = []
+        error_map = {}
+        ERROR = "error"
+
+        for batch in train_df:
+            total_rows += len(batch)
+            futures = []
+
+            with ThreadPoolExecutor() as executor:
+                for idx, row in batch.iterrows():
+                    messages = row.iloc[0]
+                    messages = replace_cot_system_message(messages) if enable_cot else messages
+                    request_data = {
+                        "messages": messages,
+                        **inference_params,
+                    }
+
+                    futures.append(
+                        executor.submit(
+                            process_request,
+                            idx,
+                            enable_cot,
+                            request_data,
+                            teacher_model_endpoint_url,
+                            teacher_model_endpoint_key
+                        )
+                    )
+
+            # wait for results to complete
+            future_results = {
+                result["idx"]: result
+                for result in [future.result() for future in as_completed(futures)]
+            }
+
+            idx = 0
+            for idx, row in batch.iterrows():
+                future_result = future_results.get(idx)
+                logger.info(future_result)
+                if future_result['exception']:
+                    logger.error(f"row {idx} failed with exception: {future_result['exception']}")
+                    error_map[ERROR] = error_map.get(ERROR, 0) + 1
+                elif future_result['status_code'] != 200:
+                    logger.warning(f"row {idx} request status_code: {future_result['status_code']} != 200")
+                    error_map[future_result['status_code']] = error_map.get(future_result['status_code'], 0) + 1
+                else:
+                    new_row = row.copy().iloc[0]
+                    answer = future_result['text']
+
+                    new_row.append(
+                        {
+                            "role": "assistant",
+                            "content": answer,
+                        }
+                    )
+                    output_data.append({"messages": new_row})
+
+        Path(output_file_path.parent).mkdir(exist_ok=True, parents=True)
+        with open(output_file_path, 'w') as f:
+            for entry in output_data:
+                f.write(json.dumps(entry) + '\n')
+
+        if error_map:
+            logger.info("Error summary, with keys denoting non-200 status codes or other errors.")
+            for k, v in error_map.items():
+                error_count += v
+                logger.warning(f"{k} => {v}")
+
+        success_ratio = float(total_rows - error_count) / total_rows
+        logger.info(f"Success rate was {success_ratio} for {input_file_path}")
+        if success_ratio < min_endpoint_success_ratio:
+            msg = f"Success ratio for dataset {input_file_path}: {success_ratio} < {min_endpoint_success_ratio}."
+            raise Exception(msg)
+
+    logger.info("Processing train file")
+    batch_process_data(train_file_path, generated_train_file_path, request_batch_size)
+    logger.info("Data generated and saved for train file")
+
+    if validation_file_path:
+        logger.info("Processing validation file")
+        batch_process_data(validation_file_path, generated_validation_file_path, request_batch_size)
+        logger.info("Data generated and saved for validation file")
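
A worked example of the success-ratio gate at the default threshold of 0.7:

```python
# With 100 rows and 25 failed requests: 75 / 100 = 0.75 >= 0.7, so the run
# passes; 31 or more failures (ratio <= 0.69) would raise and fail the job.
total_rows, error_count, min_endpoint_success_ratio = 100, 25, 0.7
success_ratio = float(total_rows - error_count) / total_rows
assert success_ratio >= min_endpoint_success_ratio
```
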
+
+
+def data_import(args: Namespace):
+    """Generate data from the teacher model and save it to the output paths."""
+    train_file_path = args.train_file_path
+    validation_file_path = args.validation_file_path
+    generated_train_file_path = args.generated_train_file_path
+    generated_validation_file_path = args.generated_validation_file_path
+
+    # add optional data-generator params
+    teacher_model_endpoint_name = args.teacher_model_endpoint_name
+    teacher_model_endpoint_url = args.teacher_model_endpoint_url
+    teacher_model_endpoint_key = args.teacher_model_endpoint_key
+    teacher_model_max_new_tokens = args.teacher_model_max_new_tokens
+    teacher_model_temperature = args.teacher_model_temperature
+    teacher_model_top_p = args.teacher_model_top_p
+    teacher_model_frequency_penalty = args.teacher_model_frequency_penalty
+    teacher_model_presence_penalty = args.teacher_model_presence_penalty
+    teacher_model_stop = args.teacher_model_stop
+    request_batch_size = args.request_batch_size
+    min_endpoint_success_ratio = args.min_endpoint_success_ratio
+    enable_cot_str = args.enable_chain_of_thought
+
+    # validate file formats
+    _validate_file_paths_with_supported_formats([args.train_file_path, args.validation_file_path])
+    logger.info("File format validation successful.")
+
+    enable_cot = enable_cot_str.lower() == "true"
+    mlclient_ws = get_workspace_mlclient()
+
+    if teacher_model_endpoint_url is None:
+        if teacher_model_endpoint_name:
+            if mlclient_ws:
+                teacher_model_endpoint_url = get_serverless_endpoint_url(mlclient_ws, teacher_model_endpoint_name)\
+                    or get_online_endpoint_url(mlclient_ws, teacher_model_endpoint_name)
+    if not teacher_model_endpoint_url:
+        raise Exception("Endpoint URL is a required parameter for data generation")
+
+    if teacher_model_endpoint_key is None:
+        if teacher_model_endpoint_name:
+            if mlclient_ws:
+                teacher_model_endpoint_key = get_serverless_endpoint_key(mlclient_ws, teacher_model_endpoint_name)\
+                    or get_online_endpoint_key(mlclient_ws, teacher_model_endpoint_name)
+    if not teacher_model_endpoint_key:
+        raise Exception("Endpoint key is a required parameter for data generation")
+
+    if teacher_model_top_p < 0 or teacher_model_top_p > 1:
+        raise Exception(
+            f"Invalid teacher_model_top_p. Value should be 0<=val<=1, but it is {teacher_model_top_p}")
+    if teacher_model_temperature < 0 or teacher_model_temperature > 1:
+        raise Exception(
+            f"Invalid teacher_model_temperature. Value should be 0<=val<=1, but it is {teacher_model_temperature}")
+    if min_endpoint_success_ratio < 0 or min_endpoint_success_ratio > 1:
+        raise Exception(
+            f"Invalid min_endpoint_success_ratio. Value should be 0<=val<=1, but it is {min_endpoint_success_ratio}")
+
+    if request_batch_size <= 0 or request_batch_size > MAX_BATCH_SIZE:
+        raise Exception(
+            f"Invalid request_batch_size. Value should be 0<val<={MAX_BATCH_SIZE}, but it is {request_batch_size}")
+
+    inference_params = {
+        MAX_NEW_TOKENS: teacher_model_max_new_tokens,
+        TEMPERATURE: teacher_model_temperature,
+        TOP_P: teacher_model_top_p
+    }
+
+    if teacher_model_frequency_penalty:
+        inference_params[FREQUENCY_PENALTY] = teacher_model_frequency_penalty
+
+    if teacher_model_presence_penalty:
+        inference_params[PRESENCE_PENALTY] = teacher_model_presence_penalty
+
+    if teacher_model_stop:
+        inference_params[STOP_TOKEN] = teacher_model_stop
+
+    if VLLM_CHAT_SCORE_PATH not in teacher_model_endpoint_url:
+        teacher_model_endpoint_url += VLLM_CHAT_SCORE_PATH
+
+    logger.info(f"Teacher Endpoint : {teacher_model_endpoint_url}")
+
+    logger.info("Running data generation")
+    generate_synthetic_data(
+        teacher_model_endpoint_url=teacher_model_endpoint_url,
+        teacher_model_endpoint_key=teacher_model_endpoint_key,
+        inference_params=inference_params,
+        request_batch_size=request_batch_size,
+        min_endpoint_success_ratio=min_endpoint_success_ratio,
+        enable_cot=enable_cot,
+        generated_train_file_path=generated_train_file_path,
+        generated_validation_file_path=generated_validation_file_path,
+        train_file_path=train_file_path,
+        validation_file_path=validation_file_path,
+    )
+
+
+@swallow_all_exceptions(time_delay=5)
+def main():
+    """Parse args and run data generation."""
+    parser = get_parser()
+    args, _ = parser.parse_known_args()
+
+    set_logging_parameters(
+        task_type="ChatCompletion",
+        acft_custom_dimensions={
+            LoggingLiterals.PROJECT_NAME: PROJECT_NAME,
+            LoggingLiterals.PROJECT_VERSION_NUMBER: VERSION,
+            LoggingLiterals.COMPONENT_NAME: COMPONENT_NAME
+        },
+        azureml_pkg_denylist_logging_patterns=LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
+        log_level=logging.INFO,
+    )
+
+    logger.info(args)
+    data_import(args)
+
+
+if __name__ == "__main__":
+    main()
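
A hypothetical local invocation of the script, mirroring the component's command; endpoint details and file paths are placeholders:

```python
import subprocess

# Placeholder URL/key and paths; only generation-related arguments are shown.
subprocess.run([
    "python", "generate_data.py",
    "--train_file_path", "train_input.jsonl",
    "--teacher_model_endpoint_url", "https://example.invalid/v1/chat/completions",
    "--teacher_model_endpoint_key", "<endpoint-key>",
    "--teacher_model_max_new_tokens", "128",
    "--teacher_model_temperature", "0.2",
    "--teacher_model_top_p", "0.1",
    "--request_batch_size", "10",
    "--min_endpoint_success_ratio", "0.7",
    "--enable_chain_of_thought", "true",
    "--generated_train_file_path", "generated_train.jsonl",
    "--generated_validation_file_path", "generated_validation.jsonl",
], check=True)
```
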
diff --git a/assets/training/distillation/components/pipeline/asset.yaml b/assets/training/distillation/components/pipeline/asset.yaml
new file mode 100644
index 0000000000..d0767a9360
--- /dev/null
+++ b/assets/training/distillation/components/pipeline/asset.yaml
@@ -0,0 +1,3 @@
+type: component
+spec: spec.yaml
+categories: ["Foundational Models", "Finetune", "Distillation"]
diff --git a/assets/training/distillation/components/pipeline/spec.yaml b/assets/training/distillation/components/pipeline/spec.yaml
new file mode 100644
index 0000000000..4dfe6e3622
--- /dev/null
+++ b/assets/training/distillation/components/pipeline/spec.yaml
@@ -0,0 +1,311 @@
+$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json
+name: oss_distillation_pipeline
+version: 0.0.1
+type: pipeline
+
+
+display_name: OSS Distillation Pipeline
+description: Component to generate data from teacher model endpoint and finetune student model on generated dataset
+
+inputs:
+  # Compute parameters
+  instance_type_data_generation:
+    type: string
+    optional: true
+    default: Standard_D4as_v4
+    description: Instance type to be used for data_generation component in case of virtual cluster compute, eg. Singularity.ND40_v2. The parameter compute_data_generation must be set to 'serverless' for instance_type to be used
+  instance_type_data_import:
+    type: string
+    optional: true
+    default: Singularity.ND96amrs_A100_v4
+    description: Instance type to be used for data_import component in case of virtual cluster compute, eg. Singularity.D8_v3. The parameter compute_data_import must be set to 'serverless' for instance_type to be used
+  instance_type_finetune:
+    type: string
+    optional: true
+    default: Singularity.ND96amrs_A100_v4
+    description: Instance type to be used for finetune component in case of virtual cluster compute, eg. Singularity.ND40_v2. The parameter compute_finetune must be set to 'serverless' for instance_type to be used
+
+  compute_data_generation:
+    type: string
+    optional: true
+    default: 'serverless'
+    description: >-
+      compute to be used for data generation eg. provide 'FT-Cluster' if
+      your compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
+      If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
+  compute_data_import:
+    type: string
+    optional: true
+    default: 'serverless'
+    description: >-
+      compute to be used for data import eg. provide 'FT-Cluster' if
+      your compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
+      If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
+  compute_finetune:
+    type: string
+    optional: true
+    default: 'serverless'
+    description: >-
+      compute to be used for finetune eg. provide 'FT-Cluster' if your
+      compute is named 'FT-Cluster'. Special characters like \ and ' are invalid in the parameter value.
+      If compute cluster name is provided, instance_type field will be ignored and the respective cluster will be used
+
+  ## OSS Data generator Input Parameters
+  train_file_path:
+    type: uri_file
+    description: Path to the registered training data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
+    mode: rw_mount
+
+  validation_file_path:
+    type: uri_file
+    optional: true
+    description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
+    mode: rw_mount
+
+  teacher_model_endpoint_name:
+    type: string
+    optional: true
+    description: Teacher model endpoint name
+
+  teacher_model_endpoint_url:
+    type: string
+    optional: true
+    description: Teacher model endpoint URL
+
+  teacher_model_endpoint_key:
+    type: string
+    optional: true
+    description: Teacher model endpoint key
+
+  teacher_model_max_new_tokens:
+    type: integer
+    default: 128
+    description: Teacher model max_new_tokens inference parameter
+
+  teacher_model_temperature:
+    type: number
+    default: 0.2
+    description: Teacher model temperature inference parameter
+
+  teacher_model_top_p:
+    type: number
+    default: 0.1
+    description: Teacher model top_p inference parameter
+
+  teacher_model_frequency_penalty:
+    type: number
+    default: 0.0
+    description: Teacher model frequency penalty inference parameter
+
+  teacher_model_presence_penalty:
+    type: number
+    default: 0.0
+    description: Teacher model presence penalty inference parameter
+
+  teacher_model_stop:
+    type: string
+    optional: true
+    description: Teacher model stop inference parameter
+
+  request_batch_size:
+    type: integer
+    default: 10
+    description: Number of data records to send to the teacher model endpoint in one request
+
+  min_endpoint_success_ratio:
+    type: number
+    default: 0.7
+    description: >
+      The minimum value of (successful_requests / total_requests) required for classifying inference as successful.
+      If (successful_requests / total_requests) < min_endpoint_success_ratio, the experiment will be marked as failed.
+      By default it is 0.7 (0 means all requests are allowed to fail while 1 means no request should fail.)
+
+  enable_chain_of_thought:
+    type: string
+    default: "false"
+    description: Enable chain of thought for data generation
+
+  ## OSS Finetune Input Parameters
+  number_of_gpu_to_use_finetuning:
+    type: integer
+    default: 1
+    optional: true
+    description: >-
+      number of gpus to be used per node for finetuning, should be equal
+      to number of gpu per node in the compute SKU used for finetune
+
+  # Continual-Finetuning model path
+  mlflow_model_path:
+    type: mlflow_model
+    optional: true
+    description: MLflow model asset path. Special characters like \ and ' are invalid in the parameter value.
+    mode: download
+  pytorch_model_path:
+    type: custom_model
+    optional: true
+    description: Pytorch model asset path. Special characters like \ and ' are invalid in the parameter value.
+    mode: download
+
+  # Training parameters
+  num_train_epochs:
+    type: integer
+    default: 1
+    optional: true
+    description: training epochs
+
+  per_device_train_batch_size:
+    type: integer
+    default: 1
+    optional: true
+    description: Train batch size
+
+  learning_rate:
+    type: number
+    default: 3e-04
+    optional: true
+    description: Start learning rate.
+
+  # Validation parameters
+  system_properties:
+    type: string
+    optional: true
+    description: Validation parameters propagated from pipeline.
+
+  # Model parameters
+  model_asset_id:
+    type: string
+    optional: false
+    description: Asset id of model
+
+  # Model registration
+  registered_model_name:
+    type: string
+    optional: true
+    description: Name of the registered model
+
+outputs:
+  output_model:
+    type: uri_folder
+    description: Output dir to save the finetuned lora weights
+    mode: rw_mount
+
+jobs:
+  oss_distillation_generate_data:
+    type: command
+    component: azureml:oss_distillation_generate_data:0.0.1
+    compute: '${{parent.inputs.compute_data_generation}}'
+    resources:
+      instance_type: '${{parent.inputs.instance_type_data_generation}}'
+      properties:
+        singularity:
+          imageVersion: ''
+          SLATier: 'Premium'
+          priority: 'Medium'
+    identity:
+      type: user_identity
+    inputs:
+      train_file_path: '${{parent.inputs.train_file_path}}'
+      validation_file_path: '${{parent.inputs.validation_file_path}}'
+      teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
+      teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
+      teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
+      enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
+      teacher_model_max_new_tokens: '${{parent.inputs.teacher_model_max_new_tokens}}'
+      teacher_model_temperature: '${{parent.inputs.teacher_model_temperature}}'
+      teacher_model_top_p: '${{parent.inputs.teacher_model_top_p}}'
+      teacher_model_frequency_penalty: '${{parent.inputs.teacher_model_frequency_penalty}}'
+      teacher_model_presence_penalty: '${{parent.inputs.teacher_model_presence_penalty}}'
+      request_batch_size: '${{parent.inputs.request_batch_size}}'
+      min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
+    outputs:
+      generated_train_file_path:
+        type: uri_file
+        path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
+      generated_validation_file_path:
+        type: uri_file
+        path: azureml://datastores/${{default_datastore}}/paths/azureml/${{name}}/${{output_name}}.jsonl
+
+  oss_text_generation_data_import:
+    type: command
+    component: azureml:oss_text_generation_data_import:0.0.19
+    compute: '${{parent.inputs.compute_data_import}}'
+    resources:
+      instance_type: '${{parent.inputs.instance_type_data_import}}'
+      properties:
+        singularity:
+          imageVersion: ''
+          SLATier: 'Premium'
+          priority: 'Medium'
+    environment_variables:
+      _AZUREML_CR_ENABLE_ITP_CAP: "false"
+    inputs:
+      train_file_path: '${{parent.jobs.oss_distillation_generate_data.outputs.generated_train_file_path}}'
+      validation_file_path: '${{parent.jobs.oss_distillation_generate_data.outputs.generated_validation_file_path}}'
+      system_properties: '${{parent.inputs.system_properties}}'
+
+  oss_chat_completion_finetune:
+    type: command
+    component: azureml:oss_chat_completion_finetune:0.0.19
+    compute: '${{parent.inputs.compute_finetune}}'
+    resources:
+      instance_type: '${{parent.inputs.instance_type_finetune}}'
+      properties:
+        singularity:
+          imageVersion: ''
+          SLATier: 'Premium'
+          priority: 'Medium'
+    environment_variables:
+      _AZUREML_CR_ENABLE_ITP_CAP: "false"
+    inputs:
+      task_name: "ChatCompletion"
+      mlflow_model_path: '${{parent.inputs.mlflow_model_path}}'
+      model_asset_id: '${{parent.inputs.model_asset_id}}'
+      pytorch_model_path: '${{parent.inputs.pytorch_model_path}}'
+      dataset_input: '${{parent.jobs.oss_text_generation_data_import.outputs.output_dataset}}'
+      batch_size: 1000
+      pad_to_max_length: "false"
+      max_seq_length: 8192
+      number_of_gpu_to_use_finetuning: '${{parent.inputs.number_of_gpu_to_use_finetuning}}'
+      apply_lora: "true"
+      lora_alpha: 128
+      lora_r: 8
+      lora_dropout: 0
+      num_train_epochs: '${{parent.inputs.num_train_epochs}}'
+      max_steps: -1
+      per_device_train_batch_size: '${{parent.inputs.per_device_train_batch_size}}'
+      per_device_eval_batch_size: '${{parent.inputs.per_device_train_batch_size}}'
+      auto_find_batch_size: "false"
+      optim: adamw_hf
+      learning_rate: '${{parent.inputs.learning_rate}}'
+      warmup_steps: 0
+      weight_decay: 0.1
+      adam_beta1: 0.9
+      adam_beta2: 0.95
+      adam_epsilon: 1e-05
+      gradient_accumulation_steps: 1
+      eval_accumulation_steps: 1
+      lr_scheduler_type: cosine
+      precision: 16
+      seed: 42
+      enable_full_determinism: "false"
+      dataloader_num_workers: 0
+      ignore_mismatched_sizes: "false"
+      max_grad_norm: 1.0
+      evaluation_strategy: epoch
+      evaluation_steps_interval: 0.0
+      eval_steps: 500
+      logging_strategy: steps
+      logging_steps: 10
+      metric_for_best_model: loss
+      resume_from_checkpoint: "false"
+      save_total_limit: 1
+      apply_early_stopping: "false"
+      early_stopping_patience: 0
+      apply_deepspeed: "true"
+      deepspeed_stage: 3
+      apply_ort: "false"
+      system_properties: '${{parent.inputs.system_properties}}'
+      registered_model_name: '${{parent.inputs.registered_model_name}}'
+    outputs:
+      output_model: '${{parent.outputs.output_model}}'
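
A hedged sketch of consuming the pipeline component with the azure-ai-ml SDK; the workspace setup, endpoint name, and model asset id are placeholders, and the exact registration flow in a given workspace may differ:

```python
from azure.ai.ml import Input, MLClient, load_component
from azure.ai.ml.dsl import pipeline
from azure.identity import DefaultAzureCredential

# Assumes a local config.json pointing at a workspace (placeholder setup).
ml_client = MLClient.from_config(credential=DefaultAzureCredential())

# Load the pipeline component directly from this PR's spec file.
distillation = load_component(source="assets/training/distillation/components/pipeline/spec.yaml")


@pipeline()
def distillation_pipeline(train_data: Input):
    # Endpoint name and model asset id below are placeholders.
    step = distillation(
        train_file_path=train_data,
        teacher_model_endpoint_name="my-teacher-endpoint",
        model_asset_id="<student-model-asset-id>",
    )
    return {"output_model": step.outputs.output_model}


job = distillation_pipeline(Input(type="uri_file", path="train_input.jsonl"))
ml_client.jobs.create_or_update(job, experiment_name="oss-distillation")
```
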