Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes for accepting WS connection #3750

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ inputs:
description: An optional configuration file to use for deployment settings. This overrides passed in parameters.
scoring_url:
type: string
optional: false
optional: true
description: The URL of the endpoint.
model_type:
type: string
Expand Down Expand Up @@ -89,7 +89,7 @@ code: ../src
environment: azureml://registries/azureml/environments/evaluation/labels/latest
command: >-
python -m aml_benchmark.batch_config_generator.main
--scoring_url '${{inputs.scoring_url}}'
$[[--scoring_url '${{inputs.scoring_url}}'']]
--model_type ${{inputs.model_type}}
--authentication_type ${{inputs.authentication_type}}
--debug_mode ${{inputs.debug_mode}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ is_deterministic: True
display_name: OSS Distillation Generate Data
description: Component to generate data from teacher model enpoint

environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/76
environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/81

inputs:
# Inputs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,15 @@ inputs:
description: Path to the registered validation data asset. The supported data formats are `jsonl`, `json`, `csv`, `tsv` and `parquet`.
mode: rw_mount

teacher_model_endpoint_url:
type: string
optional: true
description: Teacher model endpoint URL

teacher_model_endpoint_name:
type: string
optional: true
description: Teacher model endpoint name

teacher_model_endpoint_key:
teacher_model_connection_name:
type: string
optional: true
description: Teacher model endpoint key
description: Teacher model connection name

teacher_model_max_new_tokens:
type: integer
Expand Down Expand Up @@ -266,8 +261,7 @@ jobs:
train_file_path: '${{parent.inputs.train_file_path}}'
validation_file_path: '${{parent.inputs.validation_file_path}}'
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
teacher_model_connection_name: '${{parent.inputs.teacher_model_connection_name}}'
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}'
max_len_summary: '${{parent.inputs.max_len_summary}}'
Expand Down Expand Up @@ -303,7 +297,6 @@ jobs:
identity:
type: user_identity
inputs:
scoring_url: ${{parent.inputs.teacher_model_endpoint_url}}
deployment_name: ${{parent.inputs.teacher_model_endpoint_name}}
authentication_type: ${{parent.inputs.authentication_type}}
configuration_file: ${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.batch_config_connection}}
Expand Down Expand Up @@ -408,6 +401,7 @@ jobs:
hash_validation_data: '${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.hash_validation_data}}'
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
enable_chain_of_density: '${{parent.inputs.enable_chain_of_density}}'
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
data_generation_task_type: '${{parent.inputs.data_generation_task_type}}'
min_endpoint_success_ratio: '${{parent.inputs.min_endpoint_success_ratio}}'
connection_config_file: ${{parent.jobs.oss_distillation_generate_data_batch_preprocess.outputs.batch_config_connection}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ is_deterministic: False
display_name: OSS Distillation Generate Data Postprocess Batch Scoring
description: Component to prepare data returned from teacher model enpoint in batch

environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/76
environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/81

inputs:
# Inputs
Expand Down Expand Up @@ -62,6 +62,11 @@ inputs:
default: "false"
description: Enable Chain of density for text summarization

teacher_model_endpoint_name:
type: string
optional: true
description: Teacher model endpoint name

data_generation_task_type:
type: string
enum:
Expand Down Expand Up @@ -104,6 +109,7 @@ command: >-
--min_endpoint_success_ratio ${{inputs.min_endpoint_success_ratio}}
$[[--enable_chain_of_thought ${{inputs.enable_chain_of_thought}}]]
$[[--enable_chain_of_density ${{inputs.enable_chain_of_density}}]]
$[[--teacher_model_endpoint_name ${{inputs.teacher_model_endpoint_name}}]]
--data_generation_task_type ${{inputs.data_generation_task_type}}
--connection_config_file ${{inputs.connection_config_file}}
--generated_batch_train_file_path ${{outputs.generated_batch_train_file_path}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ is_deterministic: False
display_name: OSS Distillation Generate Data Batch Scoring Preprocess
description: Component to prepare data to invoke teacher model enpoint in batch

environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/76
environment: azureml://registries/azureml/environments/acft-hf-nlp-gpu/versions/81

inputs:
# Inputs
Expand All @@ -28,12 +28,7 @@ inputs:
optional: true
description: Teacher model endpoint name

teacher_model_endpoint_url:
type: string
optional: true
description: Teacher model endpoint url

teacher_model_endpoint_key:
teacher_model_connection_name:
type: string
optional: true
description: Teacher model endpoint key
Expand Down Expand Up @@ -133,8 +128,7 @@ command: >-
--train_file_path ${{inputs.train_file_path}}
$[[--validation_file_path ${{inputs.validation_file_path}}]]
$[[--teacher_model_endpoint_name ${{inputs.teacher_model_endpoint_name}}]]
$[[--teacher_model_endpoint_url ${{inputs.teacher_model_endpoint_url}}]]
$[[--teacher_model_endpoint_key ${{inputs.teacher_model_endpoint_key}}]]
$[[--teacher_model_connection_name ${{inputs.teacher_model_connection_name}}]]
--teacher_model_max_new_tokens ${{inputs.teacher_model_max_new_tokens}}
--teacher_model_temperature ${{inputs.teacher_model_temperature}}
--teacher_model_top_p ${{inputs.teacher_model_top_p}}
Expand Down
9 changes: 7 additions & 2 deletions assets/training/distillation/components/pipeline/spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ inputs:
optional: true
description: Teacher model endpoint name

teacher_model_connection_name:
type: string
optional: true
description: Teacher model connection name

teacher_model_endpoint_url:
type: string
optional: true
Expand Down Expand Up @@ -307,6 +312,7 @@ jobs:
train_file_path: '${{parent.inputs.train_file_path}}'
validation_file_path: '${{parent.inputs.validation_file_path}}'
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
teacher_model_connection_name: '${{parent.inputs.teacher_model_connection_name}}'
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
enable_chain_of_thought: '${{parent.inputs.enable_chain_of_thought}}'
Expand Down Expand Up @@ -359,9 +365,8 @@ jobs:
compute_finetune: '${{parent.inputs.compute_finetune}}'
train_file_path: '${{parent.inputs.train_file_path}}'
validation_file_path: '${{parent.inputs.validation_file_path}}'
teacher_model_endpoint_url: '${{parent.inputs.teacher_model_endpoint_url}}'
teacher_model_endpoint_name: '${{parent.inputs.teacher_model_endpoint_name}}'
teacher_model_endpoint_key: '${{parent.inputs.teacher_model_endpoint_key}}'
teacher_model_connection_name: '${{parent.inputs.teacher_model_connection_name}}'
teacher_model_max_new_tokens: '${{parent.inputs.teacher_model_max_new_tokens}}'
teacher_model_temperature: '${{parent.inputs.teacher_model_temperature}}'
teacher_model_top_p: '${{parent.inputs.teacher_model_top_p}}'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ inputs:
optional: true
description: Teacher model endpoint key

teacher_model_connection_name:
type: string
optional: true
description: Teacher model connection name

teacher_model_max_new_tokens:
type: integer
default: 128
Expand Down Expand Up @@ -148,6 +153,7 @@ command: >-
$[[--teacher_model_endpoint_name ${{inputs.teacher_model_endpoint_name}}]]
$[[--teacher_model_endpoint_url ${{inputs.teacher_model_endpoint_url}}]]
$[[--teacher_model_endpoint_key ${{inputs.teacher_model_endpoint_key}}]]
$[[--teacher_model_connection_name ${{inputs.teacher_model_connection_name}}]]
--teacher_model_max_new_tokens ${{inputs.teacher_model_max_new_tokens}}
--teacher_model_temperature ${{inputs.teacher_model_temperature}}
--teacher_model_top_p ${{inputs.teacher_model_top_p}}
Expand Down
179 changes: 179 additions & 0 deletions assets/training/distillation/src/common/api_connection_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""API key helper functions."""

import json
import requests
from requests.adapters import HTTPAdapter
from urllib3 import Retry
from typing import Tuple, Optional

from azureml.core import Run, Workspace
from azureml.core.run import _OfflineRun
from azureml._common._error_definition.azureml_error import AzureMLError

from azureml.acft.common_components.utils.error_handling.exceptions import (
ACFTSystemException,
ACFTValidationException,
)
from azureml.acft.common_components.utils.error_handling.error_definitions import (
ACFTSystemError,
ACFTUserError,
)


def _create_session_with_retry(retry: int = 3) -> requests.Session:
"""
Create requests.session with retry.

:type retry: int
rtype: Response
"""
retry_policy = _get_retry_policy(num_retry=retry)

session = requests.Session()
session.mount("https://", HTTPAdapter(max_retries=retry_policy))
session.mount("http://", HTTPAdapter(max_retries=retry_policy))
return session


def _get_retry_policy(num_retry: int = 3) -> Retry:
"""
Request retry policy with increasing backoff.

:return: Returns the msrest or requests REST client retry policy.
:rtype: urllib3.Retry
"""
backoff_factor = 0.4
retry_policy = Retry(
total=num_retry,
read=num_retry,
connect=num_retry,
backoff_factor=backoff_factor,
status_forcelist={413, 429, 500, 502, 503, 504, None},
# By default this is True. We set it to false to get the full error trace, including url and
# status code of the last retry. Otherwise, the error message is too many 500 error responses',
# which is not useful.
raise_on_status=False,
)
return retry_policy


def _send_post_request(url: str, headers: dict, payload: dict):
"""Send a POST request."""
try:
with _create_session_with_retry() as session:
response = session.post(url, data=json.dumps(payload), headers=headers)
# Raise an exception if the response contains an HTTP error status code
response.raise_for_status()
except requests.exceptions.HTTPError as errh:
raise ACFTSystemException._with_error(
AzureMLError.create(ACFTSystemError, error_details=f"HTTP Error: {errh}")
)
return response


def get_target_from_connection(connections_name: str) -> Tuple[str, Optional[str]]:
"""
Get target from connections_name.

:param connections_name: Name of the connection.
:return: target.
"""
run = Run.get_context()
if isinstance(run, _OfflineRun):
curr_ws = Workspace.from_config("config.json")
else:
curr_ws = run.experiment.workspace

if hasattr(curr_ws._auth, "get_token"):
bearer_token = curr_ws._auth.get_token(
"https://management.azure.com/.default"
).token
else:
bearer_token = curr_ws._auth.token

endpoint = curr_ws.service_context._get_endpoint("api")
url_list = [
endpoint,
"rp/workspaces/subscriptions",
curr_ws.subscription_id,
"resourcegroups",
curr_ws.resource_group,
"providers",
"Microsoft.MachineLearningServices",
"workspaces",
curr_ws.name,
"connections",
connections_name,
"listsecrets?api-version=2023-02-01-preview",
]

resp = _send_post_request(
"/".join(url_list),
{"Authorization": f"Bearer {bearer_token}", "content-type": "application/json"},
{},
)
target = resp.json().get("properties", {}).get("target")
if target is None:
msg = "Target not found in response"
raise ACFTValidationException._with_error(
AzureMLError.create(
ACFTUserError,
pii_safe_message=(msg),
)
)
return target


def get_api_key_from_connection(connections_name: str) -> Tuple[str, Optional[str]]:
"""
Get api_key from connections_name.

:param connections_name: Name of the connection.
:return: api_key, api_version.
"""
run = Run.get_context()
if isinstance(run, _OfflineRun):
curr_ws = Workspace.from_config()
else:
curr_ws = run.experiment.workspace

if hasattr(curr_ws._auth, "get_token"):
bearer_token = curr_ws._auth.get_token(
"https://management.azure.com/.default"
).token
else:
bearer_token = curr_ws._auth.token

endpoint = curr_ws.service_context._get_endpoint("api")
url_list = [
endpoint,
"rp/workspaces/subscriptions",
curr_ws.subscription_id,
"resourcegroups",
curr_ws.resource_group,
"providers",
"Microsoft.MachineLearningServices",
"workspaces",
curr_ws.name,
"connections",
connections_name,
"listsecrets?api-version=2023-02-01-preview",
]

resp = _send_post_request(
"/".join(url_list),
{"Authorization": f"Bearer {bearer_token}", "content-type": "application/json"},
{},
)

credentials = resp.json()["properties"]["credentials"]
metadata = resp.json()["properties"].get("metadata", {})
if "key" in credentials:
return credentials["key"], metadata.get("ApiVersion")
else:
if "secretAccessKey" not in credentials and "keys" in credentials:
credentials = credentials["keys"]
return credentials["secretAccessKey"], None
1 change: 0 additions & 1 deletion assets/training/distillation/src/dsl_condition_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def copy_file_contents(input_src1, ft_input_train_file_path):
parser.add_argument("--ft_input_validation_file_path", type=str)

args, _ = parser.parse_known_args()
print(f"Condition output component received args: {args}.")
if (
args.generated_batch_train_file_path is None
and args.generated_train_file_path is None
Expand Down
Loading
Loading