Skip to content

Commit

Permalink
Adding flavor map to model selector args (#1591)
Browse files Browse the repository at this point in the history
* Adding flavor map to model selector args

* Resolving code health errors

* Resolving code health errors

* Resolving code health errors

* Adddressing comments

---------

Co-authored-by: Anubha98 <[email protected]>
  • Loading branch information
Anubha98 and Anubha98 authored Oct 31, 2023
1 parent 859f598 commit 2a18224
Showing 1 changed file with 24 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,26 @@ class ModelImportConstants:
}


FLAVOR_MAP = {
# OSS Flavor
"transformers": {
"tokenizer": "components/tokenizer",
"model": "model",
"config": "model"
},
"hftransformersv2": {
"tokenizer": "data/tokenizer",
"model": "data/model",
"config": "data/config"
},
"hftransformers": {
"tokenizer": "data/tokenizer",
"model": "data/model",
"config": "data/config"
}
}


def get_model_asset_id() -> str:
"""Read the model asset id from the run context.
Expand Down Expand Up @@ -421,7 +441,7 @@ def model_selector(args: Namespace):
mlflow_data = yaml.safe_load(fp)
if mlflow_data and "flavors" in mlflow_data:
for key in mlflow_data["flavors"]:
if key in ["hftransformers", "hftransformersv2"]:
if key in FLAVOR_MAP.keys():
for key2 in mlflow_data["flavors"][key]:
if key2 == "generator_config" and args.task_name == "TextGeneration":
generator_config = mlflow_data["flavors"][key]["generator_config"]
Expand Down Expand Up @@ -480,6 +500,9 @@ def main():
azureml_pkg_denylist_logging_patterns=LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
)

# Adding flavor map to args
setattr(args, "flavor_map", FLAVOR_MAP)

model_selector(args)


Expand Down

0 comments on commit 2a18224

Please sign in to comment.