-
Notifications
You must be signed in to change notification settings - Fork 130
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Roopa G
committed
Nov 20, 2024
1 parent
a13de9f
commit bf52570
Showing
10 changed files
with
484 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
17 changes: 17 additions & 0 deletions
17
assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/conda.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Conda environment used to run the packaged MLflow model.
name: mlflow-env
channels:
  - conda-forge
dependencies:
  - python=3.10.14
  - pip<=24.0
  # Everything below is installed with pip inside the conda environment.
  - pip:
      - mlflow==2.13.2
      - cffi==1.16.0
      - cloudpickle==2.2.1
      - numpy==1.23.5
      - pandas==2.2.2
      - pyyaml==6.0.1
      - requests==2.32.3
      - timm==1.0.9,>=0.9.11
      - torch>2
      - pillow>=10
33 changes: 33 additions & 0 deletions
33
assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
"""Common Config.""" | ||
|
||
from enum import Enum | ||
|
||
from mlflow.types import DataType | ||
|
||
|
||
class _CustomEnum(Enum): | ||
@classmethod | ||
def has_value(cls, value): | ||
return value in cls._value2member_map_ | ||
|
||
class MLflowSchemaLiterals:
    """Column names and data types used in the MLflow model signature."""

    # Input columns: name and MLflow data type for each supported input.
    INPUT_COLUMN_IMAGE_DATA_TYPE = DataType.binary
    INPUT_COLUMN_IMAGE = "image"
    INPUT_COLUMN_TEXT_DATA_TYPE = DataType.string
    INPUT_COLUMN_TEXT = "text"

    # Output columns: one shared MLflow data type plus the column names.
    OUTPUT_COLUMN_DATA_TYPE = DataType.string
    OUTPUT_COLUMN_PROBS = "probs"
    OUTPUT_COLUMN_LABELS = "labels"
    OUTPUT_COLUMN_IMAGE_FEATURES = "image_features"
    OUTPUT_COLUMN_TEXT_FEATURES = "text_features"
|
||
|
||
class MLflowLiterals:
    """Literals used when exporting a model to MLflow."""

    # NOTE(review): presumably the key under which the model directory is
    # registered; confirm against the code that consumes it.
    MODEL_DIR = "model_dir"
71 changes: 71 additions & 0 deletions
71
...nagement/src/azureml/model/mgmt/processors/pyfunc/virchow/virchow_mlflow_model_wrapper.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import json | ||
|
||
import mlflow.pyfunc | ||
import timm | ||
import torch | ||
import pandas as pd | ||
import io | ||
from PIL import Image | ||
from timm.data import resolve_data_config | ||
from timm.data.transforms_factory import create_transform | ||
from timm.layers import SwiGLUPacked | ||
from config import MLflowSchemaLiterals | ||
import logging | ||
# Reuse the "mlflow" logger and lower its threshold so debug records are emitted.
logger = logging.getLogger("mlflow")
logger.setLevel(logging.DEBUG)
|
||
class VirchowModelWrapper(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper that turns images into embeddings with a timm ViT model.

    ``load_context`` builds the model and its preprocessing transforms from the
    logged artifacts; ``predict`` returns the class-token embedding of every
    input image.
    """

    def load_context(self, context):
        """Create the model and preprocessing transforms from the MLflow artifacts.

        Args:
            context: MLflow model context. Its ``artifacts`` mapping must provide
                ``"config_path"`` (a JSON file with ``"pretrained_cfg"`` and
                ``"model_args"`` entries) and ``"checkpoint_path"`` (the
                checkpoint handed to ``timm.create_model``).
        """
        config_path = context.artifacts["config_path"]
        checkpoint_path = context.artifacts["checkpoint_path"]
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        self.model = timm.create_model(
            model_name="vit_huge_patch14_224",
            checkpoint_path=checkpoint_path,
            mlp_layer=SwiGLUPacked,
            act_layer=torch.nn.SiLU,
            pretrained_cfg=config["pretrained_cfg"],
            **config["model_args"],
        )
        self.model.eval()
        self.transforms = create_transform(
            **resolve_data_config(self.model.pretrained_cfg, model=self.model)
        )

    def predict(
        self,
        context: mlflow.pyfunc.PythonModelContext,
        input_data: pd.DataFrame,
        params: dict = None,
    ) -> pd.DataFrame:
        """Embed every image in ``input_data``.

        Args:
            context: MLflow model context (unused here).
            input_data: Frame whose ``MLflowSchemaLiterals.INPUT_COLUMN_IMAGE``
                column holds one image per row, in whatever form
                ``vision_utils.process_image`` accepts.
            params: Optional inference parameters; ``None`` means "use defaults".
                ``device_type`` (default ``"cuda"``) selects the device and the
                autocast backend; ``to_half_precision`` (default ``False``)
                casts the returned embeddings to fp16.

        Returns:
            A frame with a single ``"output"`` column holding one embedding
            (a list of floats) per input row, in input order. Empty input
            yields an empty frame.
        """
        from vision_utils import process_image

        # MLflow passes ``params=None`` when the caller supplies no parameters.
        if params is None:
            params = {}
        device_type = params.get("device_type", "cuda")
        to_half_precision = params.get("to_half_precision", False)

        # process_image presumably returns encoded image bytes, since its
        # result is wrapped in BytesIO -- TODO confirm against vision_utils.
        image_tensors = [
            self.transforms(Image.open(io.BytesIO(process_image(image))))
            for image in input_data[MLflowSchemaLiterals.INPUT_COLUMN_IMAGE]
        ]
        if not image_tensors:
            return pd.DataFrame({"output": []})

        # Run on the requested device; if CUDA was requested but is not
        # available, stay on CPU (autocast then warns and disables itself).
        if device_type == "cuda" and not torch.cuda.is_available():
            device = torch.device("cpu")
        else:
            device = torch.device(device_type)
        self.model.to(device)

        # One batch for all rows: N x C x H x W.
        batch = torch.stack(image_tensors).to(device)

        with torch.inference_mode(), torch.autocast(
            device_type=device_type, dtype=torch.float16
        ):
            output = self.model(batch)  # N x tokens x dim

        # Use the class token (first token) only as the embedding: N x dim.
        embedding = output[:, 0]

        # The model output will be fp32 because the final operation is a
        # LayerNorm that is run in mixed precision. Optionally convert the
        # embedding to fp16 for efficiency in downstream use.
        if to_half_precision:
            embedding = embedding.to(torch.float16)

        return pd.DataFrame({"output": embedding.cpu().tolist()})
Oops, something went wrong.