
Commit

Fixes
Roopa G committed Nov 20, 2024
1 parent a13de9f commit bf52570
Showing 10 changed files with 484 additions and 11 deletions.
@@ -79,6 +79,7 @@ inputs:
- mask-generation
- video-multi-object-tracking
- visual-question-answering
- image-feature-extraction
description: The Hugging Face task the model was trained on. A required parameter for the transformers MLflow flavor. Can be provided as input here or in the model_download_metadata JSON file.
optional: true

23 changes: 12 additions & 11 deletions assets/training/model_management/components/import_model/spec.yaml
@@ -100,6 +100,7 @@ inputs:
- mask-generation
- video-multi-object-tracking
- visual-question-answering
- image-feature-extraction
optional: true
type: string

@@ -218,8 +219,8 @@ outputs:

jobs:
validation_trigger_import:
component: azureml:validation_trigger_import:0.0.13
# component: azureml://registries/azureml/components/validation_trigger_import/versions/0.0.11
# component: azureml:validation_trigger_import:0.0.13
component: azureml://registries/azureml/components/validation_trigger_import/versions/0.0.11
compute: ${{parent.inputs.compute}}
resources:
instance_type: '${{parent.inputs.instance_type}}'
@@ -253,8 +254,8 @@ jobs:
type: uri_file

download_model:
component: azureml:download_model:0.0.29
# component: azureml://registries/azureml/components/download_model/versions/0.0.28
# component: azureml:download_model:0.0.29
component: azureml://registries/azureml/components/download_model/versions/0.0.28
compute: ${{parent.inputs.compute}}
resources:
instance_type: '${{parent.inputs.instance_type}}'
@@ -273,7 +274,7 @@
type: uri_folder

mmdetection_image_objectdetection_instancesegmentation_model_import:
component: azureml:mmdetection_image_objectdetection_instancesegmentation_model_import/versions/0.0.19
component: azureml://registries/azureml/components/mmdetection_image_objectdetection_instancesegmentation_model_import/versions/0.0.19
compute: ${{parent.inputs.compute}}
resources:
instance_type: '${{parent.inputs.instance_type}}'
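One fix above stands out: the old mmdetection reference mixed the `azureml:<name>:<version>` shorthand with a `/versions/<version>` path suffix, which is not a valid component reference, while the full `azureml://registries/...` URI is. A hedged sketch of resolving the corrected reference with the azure-ai-ml SDK (the client construction is illustrative, not part of this commit):

from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

# Resolve the component that the pipeline spec pins by full registry URI.
client = MLClient(credential=DefaultAzureCredential(), registry_name="azureml")
component = client.components.get(
    name="mmdetection_image_objectdetection_instancesegmentation_model_import",
    version="0.0.19",
)
print(component.id)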
@@ -288,8 +289,8 @@
type: uri_file

convert_model_to_mlflow:
component: azureml:convert_model_to_mlflow:0.0.34
# component: azureml://registries/azureml-preview-test1/components/convert_model_to_mlflow/versions/0.0.33-groopa-test
# component: azureml:convert_model_to_mlflow:0.0.34
component: azureml://registries/azureml-preview-test1/components/convert_model_to_mlflow/versions/0.0.39-groopa-test
compute: ${{parent.inputs.compute}}
resources:
instance_type: '${{parent.inputs.instance_type}}'
@@ -318,8 +319,8 @@
type: mlflow_model

mlflow_model_local_validation:
component: azureml:mlflow_model_local_validation:0.0.16
# component: azureml://registries/azureml/components/mlflow_model_local_validation/versions/0.0.14
# component: azureml:mlflow_model_local_validation:0.0.16
component: azureml://registries/azureml/components/mlflow_model_local_validation/versions/0.0.14
compute: ${{parent.inputs.compute}}
resources:
instance_type: '${{parent.inputs.instance_type}}'
@@ -332,8 +333,8 @@
mlflow_model_folder: ${{parent.outputs.mlflow_model_folder}}

register_model:
component: azureml:register_model:0.0.19
# component: azureml://registries/azureml/components/register_model/versions/0.0.17
# component: azureml:register_model:0.0.19
component: azureml://registries/azureml/components/register_model/versions/0.0.17
compute: ${{parent.inputs.compute}}
resources:
instance_type: '${{parent.inputs.instance_type}}'
@@ -67,6 +67,7 @@ inputs:
- image-classification
- text-to-image
- chat-completion
- image-feature-extraction
optional: true
type: string

@@ -36,6 +36,7 @@
    DinoV2MLFlowConvertor,
    LLaVAMLFlowConvertor,
    SegmentAnythingMLFlowConvertor,
    VirchowMLFlowConvertor,
)


@@ -84,6 +85,10 @@ def get_mlflow_convertor(model_framework, model_dir, output_dir, temp_dir, translate_params):
        return SegmentAnythingMLflowConvertorFactory.create_mlflow_convertor(
            model_dir, output_dir, temp_dir, translate_params
        )
    elif task == PyFuncSupportedTasks.IMAGE_FEATURE_EXTRACTION.value:
        return VirchowMLflowConvertorFactory.create_mlflow_convertor(
            model_dir, output_dir, temp_dir, translate_params
        )
    else:
        raise Exception(
            f"Models from {model_framework} for task {task} and model {model_id} "
@@ -297,7 +302,18 @@ def create_mlflow_convertor(model_dir, output_dir, temp_dir, translate_params):
            temp_dir=temp_dir,
            translate_params=translate_params,
        )


class VirchowMLflowConvertorFactory(MLflowConvertorFactoryInterface):
    """Factory class for the Virchow model."""

    def create_mlflow_convertor(model_dir, output_dir, temp_dir, translate_params):
        """Create MLflow convertor for the Virchow model."""
        return VirchowMLFlowConvertor(
            model_dir=model_dir,
            output_dir=output_dir,
            temp_dir=temp_dir,
            translate_params=translate_params,
        )

class MMLabTrackingMLflowConvertorFactory(MLflowConvertorFactoryInterface):
    """Factory class for MMTrack video model family."""
@@ -56,6 +56,8 @@ class SupportedTasks(_CustomEnum):
    IMAGE_OBJECT_DETECTION = "image-object-detection"
    IMAGE_INSTANCE_SEGMENTATION = "image-instance-segmentation"

    # Virchow
    IMAGE_FEATURE_EXTRACTION = "image-feature-extraction"


class ModelFamilyPrefixes(_CustomEnum):
    """Prefixes for some of the models converted to PyFunc MLflow."""
@@ -65,3 +67,6 @@ class ModelFamilyPrefixes(_CustomEnum):

    # DinoV2 model family.
    DINOV2 = "facebook/dinov2"

    # Virchow model family.
    VIRCHOW = "paige-ai/Virchow"
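These two additions are what the conversion pipeline keys off. A quick sketch of the lookups, assuming the package-relative import path below (the `has_value` helper comes from `_CustomEnum`):

from azureml.model.mgmt.processors.pyfunc.config import ModelFamilyPrefixes, SupportedTasks

# The new task and model-family prefix registered in this commit.
assert SupportedTasks.has_value("image-feature-extraction")
assert ModelFamilyPrefixes.VIRCHOW.value == "paige-ai/Virchow"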
@@ -41,6 +41,8 @@
    MLflowSchemaLiterals as SegmentAnythingMLFlowSchemaLiterals, MLflowLiterals as SegmentAnythingMLflowLiterals
from azureml.model.mgmt.processors.pyfunc.vision.config import \
    MLflowSchemaLiterals as VisionMLFlowSchemaLiterals, MMDetLiterals
from azureml.model.mgmt.processors.pyfunc.virchow.config import \
    MLflowSchemaLiterals as VirchowMLFlowSchemaLiterals, MLflowLiterals as VirchowMLflowLiterals


logger = get_logger(__name__)
@@ -1136,3 +1138,95 @@ def save_as_mlflow(self):
            conda_env=conda_env_file,
            code_path=None,
        )


class VirchowMLFlowConvertor(PyFuncMLFLowConvertor):
    """PyFunc MLflow convertor for Virchow models."""

    MODEL_DIR = os.path.join(os.path.dirname(__file__), "virchow")
    COMMON_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "common")

    def __init__(self, **kwargs):
        """Initialize MLflow convertor for Virchow models."""
        super().__init__(**kwargs)
        if self._task not in [SupportedTasks.IMAGE_FEATURE_EXTRACTION.value]:
            raise Exception("Unsupported task")

    def get_model_signature(self) -> ModelSignature:
        """Return MLflow model signature with input and output schema for the given input task.

        :return: MLflow model signature.
        :rtype: mlflow.models.signature.ModelSignature
        """
        input_schema = Schema(
            [
                ColSpec(VirchowMLFlowSchemaLiterals.INPUT_COLUMN_IMAGE_DATA_TYPE,
                        VirchowMLFlowSchemaLiterals.INPUT_COLUMN_IMAGE),
                ColSpec(VirchowMLFlowSchemaLiterals.INPUT_COLUMN_TEXT_DATA_TYPE,
                        VirchowMLFlowSchemaLiterals.INPUT_COLUMN_TEXT),
            ]
        )

        if self._task == SupportedTasks.IMAGE_FEATURE_EXTRACTION.value:
            output_schema = Schema(
                [
                    ColSpec(VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_DATA_TYPE,
                            VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_PROBS),
                    ColSpec(VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_DATA_TYPE,
                            VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_LABELS),
                    ColSpec(VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_DATA_TYPE,
                            VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_IMAGE_FEATURES),
                    ColSpec(VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_DATA_TYPE,
                            VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_TEXT_FEATURES),
                ]
            )
        else:
            raise Exception("Unsupported task")

        return ModelSignature(inputs=input_schema, outputs=output_schema)

    def save_as_mlflow(self):
        """Prepare model for save to MLflow."""
        sys.path.append(self.MODEL_DIR)

        from virchow_mlflow_model_wrapper import VirchowModelWrapper
        mlflow_model_wrapper = VirchowModelWrapper()

        artifacts_dict = self._prepare_artifacts_dict()
        conda_env_file = os.path.join(self.MODEL_DIR, "conda.yaml")
        code_path = self._get_code_path()

        super()._save(
            mlflow_model_wrapper=mlflow_model_wrapper,
            artifacts_dict=artifacts_dict,
            conda_env=conda_env_file,
            code_path=code_path,
        )

    def _get_code_path(self):
        """Return code path for saving mlflow model depending on task type.

        :return: code path
        :rtype: List[str]
        """
        code_path = [
            os.path.join(self.MODEL_DIR, "virchow_mlflow_model_wrapper.py"),
            os.path.join(self.MODEL_DIR, "config.py"),
            os.path.join(self.COMMON_DIR, "vision_utils.py"),
        ]

        return code_path

    def _prepare_artifacts_dict(self) -> Dict:
        """Prepare artifacts dict for MLflow model.

        :return: artifacts dict
        :rtype: Dict
        """
        artifacts_dict = {
            VirchowMLflowLiterals.MODEL_DIR: self._model_dir
        }
        return artifacts_dict
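For orientation, `super()._save` above forwards these values to MLflow's pyfunc flavor. A rough, hypothetical equivalent of the direct call is sketched below, mirroring this repo's `code_path`/`artifacts` usage (newer MLflow releases name the parameter `code_paths`). Note that the wrapper's `load_context` reads `config_path` and `checkpoint_path` artifact keys, so the base `_save` presumably expands the model directory into those entries:

import mlflow.pyfunc

from virchow_mlflow_model_wrapper import VirchowModelWrapper

mlflow.pyfunc.save_model(
    path="./mlflow_out",                              # hypothetical output folder
    python_model=VirchowModelWrapper(),
    artifacts={"model_dir": "./virchow_checkpoint"},  # key from VirchowMLflowLiterals.MODEL_DIR
    conda_env="virchow/conda.yaml",
    code_path=[
        "virchow/virchow_mlflow_model_wrapper.py",
        "virchow/config.py",
        "common/vision_utils.py",
    ],
)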


@@ -0,0 +1,17 @@
channels:
- conda-forge
dependencies:
- python=3.10.14
- pip<=24.0
- pip:
  - mlflow==2.13.2
  - cffi==1.16.0
  - cloudpickle==2.2.1
  - numpy==1.23.5
  - pandas==2.2.2
  - pyyaml==6.0.1
  - requests==2.32.3
  - timm==1.0.9
  - torch>2
  - pillow>=10
name: mlflow-env
@@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Common Config."""

from enum import Enum

from mlflow.types import DataType


class _CustomEnum(Enum):
    @classmethod
    def has_value(cls, value):
        return value in cls._value2member_map_


class MLflowSchemaLiterals:
    """MLflow model signature related schema."""

    INPUT_COLUMN_IMAGE_DATA_TYPE = DataType.binary
    INPUT_COLUMN_IMAGE = "image"
    INPUT_COLUMN_TEXT_DATA_TYPE = DataType.string
    INPUT_COLUMN_TEXT = "text"
    OUTPUT_COLUMN_DATA_TYPE = DataType.string
    OUTPUT_COLUMN_PROBS = "probs"
    OUTPUT_COLUMN_LABELS = "labels"
    OUTPUT_COLUMN_IMAGE_FEATURES = "image_features"
    OUTPUT_COLUMN_TEXT_FEATURES = "text_features"


class MLflowLiterals:
    """MLflow export related literals."""

    MODEL_DIR = "model_dir"
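These literals feed directly into `VirchowMLFlowConvertor.get_model_signature` above. For reference, a tiny sketch of the input schema they produce (printed output is approximate):

from mlflow.types import ColSpec, Schema

from config import MLflowSchemaLiterals

input_schema = Schema(
    [
        ColSpec(MLflowSchemaLiterals.INPUT_COLUMN_IMAGE_DATA_TYPE, MLflowSchemaLiterals.INPUT_COLUMN_IMAGE),
        ColSpec(MLflowSchemaLiterals.INPUT_COLUMN_TEXT_DATA_TYPE, MLflowSchemaLiterals.INPUT_COLUMN_TEXT),
    ]
)
print(input_schema)  # roughly: ['image': binary, 'text': string]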
@@ -0,0 +1,71 @@
import io
import json
import logging

import mlflow.pyfunc
import pandas as pd
import timm
import torch
from config import MLflowSchemaLiterals
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked

# Route wrapper logs through the mlflow logger at debug level.
logger = logging.getLogger("mlflow")
logger.setLevel(logging.DEBUG)

class VirchowModelWrapper(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper for the Virchow tile-embedding model."""

    def load_context(self, context):
        """Load the timm checkpoint and build matching inference transforms."""
        config_path = context.artifacts["config_path"]
        checkpoint_path = context.artifacts["checkpoint_path"]
        with open(config_path) as f:
            config = json.load(f)
        # Virchow is a ViT-H/14 variant with SwiGLU MLP layers and SiLU activations.
        self.model = timm.create_model(
            model_name="vit_huge_patch14_224",
            checkpoint_path=checkpoint_path,
            mlp_layer=SwiGLUPacked,
            act_layer=torch.nn.SiLU,
            pretrained_cfg=config["pretrained_cfg"],
            **config["model_args"]
        )
        self.model.eval()
        self.transforms = create_transform(
            **resolve_data_config(self.model.pretrained_cfg, model=self.model)
        )

    def predict(self, context: mlflow.pyfunc.PythonModelContext, input_data: pd.DataFrame,
                params: dict = None) -> pd.DataFrame:
        """Return one embedding row per input image."""
        from vision_utils import process_image

        pil_images = [
            Image.open(io.BytesIO(process_image(image)))
            for image in input_data[MLflowSchemaLiterals.INPUT_COLUMN_IMAGE]
        ]
        # Apply the eval transforms and batch the tensors, size: N x 3 x 224 x 224.
        batch = torch.stack([self.transforms(image) for image in pil_images])

        params = params or {}
        device_type = params.get("device_type", "cuda")
        to_half_precision = params.get("to_half_precision", False)

        with torch.inference_mode(), torch.autocast(
            device_type=device_type, dtype=torch.float16
        ):
            output = self.model(batch)  # size: N x 257 x 1280

        class_token = output[:, 0]  # size: N x 1280
        patch_tokens = output[:, 1:]  # size: N x 256 x 1280 (unused here)

        # Use the class token only as the embedding, size: N x 1280.
        embedding = class_token

        # The model output is fp32 because the final LayerNorm runs in mixed
        # precision; optionally convert to fp16 for downstream efficiency.
        if to_half_precision:
            embedding = embedding.to(torch.float16)

        df_result = pd.DataFrame()
        df_result["output"] = embedding.tolist()
        return df_result
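Once saved, the wrapper scores like any other pyfunc model. A minimal sketch, assuming the model folder from the conversion step, a CUDA device, and that `process_image` accepts base64-encoded image strings (as the shared `vision_utils` helpers do for the other vision wrappers in this repo):

import base64

import mlflow.pyfunc
import pandas as pd

model = mlflow.pyfunc.load_model("./mlflow_out")  # hypothetical save location

with open("tile.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

input_df = pd.DataFrame({"image": [image_b64]})
result = model.predict(input_df, params={"device_type": "cuda"})
print(len(result["output"][0]))  # 1280-dimensional class-token embedding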