From bf5257002f5fd618c69e0710ac534677d3a4c074 Mon Sep 17 00:00:00 2001
From: Roopa G
Date: Wed, 20 Nov 2024 11:53:17 +0530
Subject: [PATCH] Add image-feature-extraction task and Virchow MLflow
 convertor

---
 .../convert_model_to_mlflow/spec.yaml         |   1 +
 .../components/import_model/spec.yaml         |  23 +-
 .../validation_trigger_import/spec.yaml       |   1 +
 .../azureml/model/mgmt/processors/factory.py  |  18 ++
 .../model/mgmt/processors/pyfunc/config.py    |   5 +
 .../mgmt/processors/pyfunc/convertors.py      |  82 ++++++
 .../mgmt/processors/pyfunc/virchow/conda.yaml |  17 ++
 .../mgmt/processors/pyfunc/virchow/config.py  |  29 ++
 .../virchow/virchow_mlflow_model_wrapper.py   |  99 ++++++++
 .../processors/pyfunc/virchow/vision_utils.py | 234 ++++++++++++++++++
 10 files changed, 498 insertions(+), 11 deletions(-)
 create mode 100644 assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/conda.yaml
 create mode 100644 assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/config.py
 create mode 100644 assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/virchow_mlflow_model_wrapper.py
 create mode 100644 assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/vision_utils.py

diff --git a/assets/training/model_management/components/convert_model_to_mlflow/spec.yaml b/assets/training/model_management/components/convert_model_to_mlflow/spec.yaml
index db8f0d7af2..dd421bfd5e 100644
--- a/assets/training/model_management/components/convert_model_to_mlflow/spec.yaml
+++ b/assets/training/model_management/components/convert_model_to_mlflow/spec.yaml
@@ -79,6 +79,7 @@ inputs:
     - mask-generation
    - video-multi-object-tracking
     - visual-question-answering
+    - image-feature-extraction
     description: A Hugging Face task on which the model was trained. A required parameter
       for transformers MLflow flavor. Can be provided as input here or in model_download_metadata JSON file.
     optional: true

diff --git a/assets/training/model_management/components/import_model/spec.yaml b/assets/training/model_management/components/import_model/spec.yaml
index b416f2822e..87b71e8828 100644
--- a/assets/training/model_management/components/import_model/spec.yaml
+++ b/assets/training/model_management/components/import_model/spec.yaml
@@ -100,6 +100,7 @@ inputs:
     - mask-generation
     - video-multi-object-tracking
     - visual-question-answering
+    - image-feature-extraction
     optional: true
     type: string
 
@@ -218,8 +219,8 @@ outputs:
 
 jobs:
   validation_trigger_import:
-    component: azureml:validation_trigger_import:0.0.13
-    # component: azureml://registries/azureml/components/validation_trigger_import/versions/0.0.11
+    # component: azureml:validation_trigger_import:0.0.13
+    component: azureml://registries/azureml/components/validation_trigger_import/versions/0.0.11
     compute: ${{parent.inputs.compute}}
     resources:
       instance_type: '${{parent.inputs.instance_type}}'
@@ -253,8 +254,8 @@ jobs:
       type: uri_file
 
   download_model:
-    component: azureml:download_model:0.0.29
-    # component: azureml://registries/azureml/components/download_model/versions/0.0.28
+    # component: azureml:download_model:0.0.29
+    component: azureml://registries/azureml/components/download_model/versions/0.0.28
     compute: ${{parent.inputs.compute}}
     resources:
       instance_type: '${{parent.inputs.instance_type}}'
@@ -273,7 +274,7 @@ jobs:
       type: uri_folder
 
   mmdetection_image_objectdetection_instancesegmentation_model_import:
-    component: azureml:mmdetection_image_objectdetection_instancesegmentation_model_import/versions/0.0.19
+    component: azureml://registries/azureml/components/mmdetection_image_objectdetection_instancesegmentation_model_import/versions/0.0.19
     compute: ${{parent.inputs.compute}}
     resources:
       instance_type: '${{parent.inputs.instance_type}}'
@@ -288,8 +289,8 @@ jobs:
       type: uri_file
 
   convert_model_to_mlflow:
-    component: azureml:convert_model_to_mlflow:0.0.34
-    # component: azureml://registries/azureml-preview-test1/components/convert_model_to_mlflow/versions/0.0.33-groopa-test
+    # component: azureml:convert_model_to_mlflow:0.0.34
+    component: azureml://registries/azureml-preview-test1/components/convert_model_to_mlflow/versions/0.0.39-groopa-test
     compute: ${{parent.inputs.compute}}
     resources:
       instance_type: '${{parent.inputs.instance_type}}'
@@ -318,8 +319,8 @@ jobs:
       type: mlflow_model
 
   mlflow_model_local_validation:
-    component: azureml:mlflow_model_local_validation:0.0.16
-    # component: azureml://registries/azureml/components/mlflow_model_local_validation/versions/0.0.14
+    # component: azureml:mlflow_model_local_validation:0.0.16
+    component: azureml://registries/azureml/components/mlflow_model_local_validation/versions/0.0.14
     compute: ${{parent.inputs.compute}}
     resources:
       instance_type: '${{parent.inputs.instance_type}}'
@@ -332,8 +333,8 @@ jobs:
       mlflow_model_folder: ${{parent.outputs.mlflow_model_folder}}
 
   register_model:
-    component: azureml:register_model:0.0.19
-    # component: azureml://registries/azureml/components/register_model/versions/0.0.17
+    # component: azureml:register_model:0.0.19
+    component: azureml://registries/azureml/components/register_model/versions/0.0.17
     compute: ${{parent.inputs.compute}}
     resources:
       instance_type: '${{parent.inputs.instance_type}}'

diff --git a/assets/training/model_management/components/validation_trigger_import/spec.yaml b/assets/training/model_management/components/validation_trigger_import/spec.yaml
index 94b87f080b..9db07a753c 100644
--- a/assets/training/model_management/components/validation_trigger_import/spec.yaml
+++ b/assets/training/model_management/components/validation_trigger_import/spec.yaml
@@ -67,6 +67,7 @@ inputs:
     - image-classification
     - text-to-image
     - chat-completion
+    - image-feature-extraction
     optional: true
     type: string
 
diff --git a/assets/training/model_management/src/azureml/model/mgmt/processors/factory.py b/assets/training/model_management/src/azureml/model/mgmt/processors/factory.py
index bb57653ebf..d100a613fd 100644
--- a/assets/training/model_management/src/azureml/model/mgmt/processors/factory.py
+++ b/assets/training/model_management/src/azureml/model/mgmt/processors/factory.py
@@ -36,6 +36,7 @@
     DinoV2MLFlowConvertor,
     LLaVAMLFlowConvertor,
     SegmentAnythingMLFlowConvertor,
+    VirchowMLFlowConvertor,
 )
 
 
@@ -84,6 +85,10 @@ def get_mlflow_convertor(model_framework, model_dir, output_dir, temp_dir, trans
         return SegmentAnythingMLflowConvertorFactory.create_mlflow_convertor(
             model_dir, output_dir, temp_dir, translate_params
         )
+    elif task == PyFuncSupportedTasks.IMAGE_FEATURE_EXTRACTION.value:
+        return VirchowMLflowConvertorFactory.create_mlflow_convertor(
+            model_dir, output_dir, temp_dir, translate_params
+        )
     else:
         raise Exception(
             f"Models from {model_framework} for task {task} and model {model_id} "
@@ -297,7 +302,20 @@ def create_mlflow_convertor(model_dir, output_dir, temp_dir, translate_params):
             temp_dir=temp_dir,
             translate_params=translate_params,
         )
+
+
+class VirchowMLflowConvertorFactory(MLflowConvertorFactoryInterface):
+    """Factory class for Virchow model family."""
+
+    def create_mlflow_convertor(model_dir, output_dir, temp_dir, translate_params):
+        """Create MLflow convertor for Virchow model."""
+        return VirchowMLFlowConvertor(
+            model_dir=model_dir,
+            output_dir=output_dir,
+            temp_dir=temp_dir,
+            translate_params=translate_params,
+        )
 
 
 class MMLabTrackingMLflowConvertorFactory(MLflowConvertorFactoryInterface):
     """Factory class for MMTrack video model family."""

diff --git a/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/config.py b/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/config.py
index 6637ff1204..a66a3ba06b 100644
--- a/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/config.py
+++ b/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/config.py
@@ -56,6 +56,8 @@ class SupportedTasks(_CustomEnum):
     IMAGE_OBJECT_DETECTION = "image-object-detection"
     IMAGE_INSTANCE_SEGMENTATION = "image-instance-segmentation"
+    # Virchow
+    IMAGE_FEATURE_EXTRACTION = "image-feature-extraction"
 
 
 class ModelFamilyPrefixes(_CustomEnum):
     """Prefixes for some of the models converted to PyFunc MLflow."""
@@ -65,3 +67,6 @@ class ModelFamilyPrefixes(_CustomEnum):
 
     # DinoV2 model family.
     DINOV2 = "facebook/dinov2"
+
+    # Virchow model family.
+    VIRCHOW = "paige-ai/Virchow"

diff --git a/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/convertors.py b/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/convertors.py
index e20fc66285..c2dd452f44 100644
--- a/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/convertors.py
+++ b/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/convertors.py
@@ -41,6 +41,8 @@
     MLflowSchemaLiterals as SegmentAnythingMLFlowSchemaLiterals, MLflowLiterals as SegmentAnythingMLflowLiterals
 from azureml.model.mgmt.processors.pyfunc.vision.config import \
     MLflowSchemaLiterals as VisionMLFlowSchemaLiterals, MMDetLiterals
+from azureml.model.mgmt.processors.pyfunc.virchow.config import \
+    MLflowSchemaLiterals as VirchowMLFlowSchemaLiterals, MLflowLiterals as VirchowMLflowLiterals
 
 logger = get_logger(__name__)
 
@@ -1136,3 +1138,83 @@ def save_as_mlflow(self):
         conda_env=conda_env_file,
         code_path=None,
     )
+
+
+class VirchowMLFlowConvertor(PyFuncMLFLowConvertor):
+    """PyFunc MLflow convertor for Virchow models."""
+
+    MODEL_DIR = os.path.join(os.path.dirname(__file__), "virchow")
+
+    def __init__(self, **kwargs):
+        """Initialize MLflow convertor for Virchow models."""
+        super().__init__(**kwargs)
+        if self._task not in [SupportedTasks.IMAGE_FEATURE_EXTRACTION.value]:
+            raise Exception("Unsupported task")
+
+    def get_model_signature(self) -> ModelSignature:
+        """Return MLflow model signature with input and output schema for the given input task.
+
+        :return: MLflow model signature.
+        :rtype: mlflow.models.signature.ModelSignature
+        """
+        input_schema = Schema(
+            [
+                ColSpec(VirchowMLFlowSchemaLiterals.INPUT_COLUMN_IMAGE_DATA_TYPE,
+                        VirchowMLFlowSchemaLiterals.INPUT_COLUMN_IMAGE),
+            ]
+        )
+
+        if self._task == SupportedTasks.IMAGE_FEATURE_EXTRACTION.value:
+            output_schema = Schema(
+                [
+                    ColSpec(VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_DATA_TYPE,
+                            VirchowMLFlowSchemaLiterals.OUTPUT_COLUMN_IMAGE_FEATURES),
+                ]
+            )
+        else:
+            raise Exception("Unsupported task")
+
+        return ModelSignature(inputs=input_schema, outputs=output_schema)
+
+    def save_as_mlflow(self):
+        """Prepare model for save to MLflow."""
+        sys.path.append(self.MODEL_DIR)
+
+        from virchow_mlflow_model_wrapper import VirchowModelWrapper
+        mlflow_model_wrapper = VirchowModelWrapper()
+
+        artifacts_dict = self._prepare_artifacts_dict()
+        conda_env_file = os.path.join(self.MODEL_DIR, "conda.yaml")
+        code_path = self._get_code_path()
+
+        super()._save(
+            mlflow_model_wrapper=mlflow_model_wrapper,
+            artifacts_dict=artifacts_dict,
+            conda_env=conda_env_file,
+            code_path=code_path,
+        )
+
+    def _get_code_path(self):
+        """Return code path for saving mlflow model depending on task type.
+
+        :return: code path
+        :rtype: List[str]
+        """
+        code_path = [
+            os.path.join(self.MODEL_DIR, "virchow_mlflow_model_wrapper.py"),
+            os.path.join(self.MODEL_DIR, "config.py"),
+            os.path.join(self.MODEL_DIR, "vision_utils.py"),
+        ]
+
+        return code_path
+
+    def _prepare_artifacts_dict(self) -> Dict:
+        """Prepare artifacts dict for MLflow model.
+
+        :return: artifacts dict
+        :rtype: Dict
+        """
+        artifacts_dict = {
+            VirchowMLflowLiterals.MODEL_DIR: self._model_dir
+        }
+        return artifacts_dict
diff --git a/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/conda.yaml b/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/conda.yaml
new file mode 100644
index 0000000000..43b178be2a
--- /dev/null
+++ b/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/conda.yaml
@@ -0,0 +1,17 @@
+channels:
+- conda-forge
+dependencies:
+- python=3.10.14
+- pip<=24.0
+- pip:
+  - mlflow==2.13.2
+  - cffi==1.16.0
+  - cloudpickle==2.2.1
+  - numpy==1.23.5
+  - pandas==2.2.2
+  - pyyaml==6.0.1
+  - requests==2.32.3
+  - timm==1.0.9
+  - torch>2
+  - pillow>=10
+name: mlflow-env
diff --git a/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/config.py b/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/config.py
new file mode 100644
index 0000000000..03815582bf
--- /dev/null
+++ b/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/config.py
@@ -0,0 +1,29 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Common Config."""
+
+from enum import Enum
+
+from mlflow.types import DataType
+
+
+class _CustomEnum(Enum):
+    @classmethod
+    def has_value(cls, value):
+        return value in cls._value2member_map_
+
+
+class MLflowSchemaLiterals:
+    """MLflow model signature related schema."""
+
+    INPUT_COLUMN_IMAGE_DATA_TYPE = DataType.binary
+    INPUT_COLUMN_IMAGE = "image"
+    OUTPUT_COLUMN_DATA_TYPE = DataType.string
+    OUTPUT_COLUMN_IMAGE_FEATURES = "image_features"
+
+
+class MLflowLiterals:
+    """MLflow export related literals."""
+
+    MODEL_DIR = "model_dir"
diff --git a/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/virchow_mlflow_model_wrapper.py b/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/virchow_mlflow_model_wrapper.py
new file mode 100644
index 0000000000..36af9ab908
--- /dev/null
+++ b/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/virchow_mlflow_model_wrapper.py
@@ -0,0 +1,99 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""MLflow pyfunc wrapper for Virchow image feature extraction models."""
+
+import io
+import json
+import logging
+import os
+
+import mlflow.pyfunc
+import pandas as pd
+import timm
+import torch
+from PIL import Image
+from timm.data import resolve_data_config
+from timm.data.transforms_factory import create_transform
+from timm.layers import SwiGLUPacked
+
+from config import MLflowLiterals, MLflowSchemaLiterals
+
+logger = logging.getLogger("mlflow")
+# Set log level to debugging
+logger.setLevel(logging.DEBUG)
+
+
+class VirchowModelWrapper(mlflow.pyfunc.PythonModel):
+    """MLflow pyfunc wrapper for the Virchow tile-level feature extractor."""
+
+    def load_context(self, context):
+        """Load the timm model and inference transforms from model artifacts."""
+        model_dir = context.artifacts[MLflowLiterals.MODEL_DIR]
+        # The downloaded Virchow model directory is expected to contain
+        # config.json and pytorch_model.bin.
+        config_path = os.path.join(model_dir, "config.json")
+        checkpoint_path = os.path.join(model_dir, "pytorch_model.bin")
+        with open(config_path) as f:
+            config = json.load(f)
+        self.model = timm.create_model(
+            model_name="vit_huge_patch14_224",
+            checkpoint_path=checkpoint_path,
+            mlp_layer=SwiGLUPacked,
+            act_layer=torch.nn.SiLU,
+            pretrained_cfg=config["pretrained_cfg"],
+            **config["model_args"]
+        )
+        self.model.eval()
+        self.transforms = create_transform(
+            **resolve_data_config(self.model.pretrained_cfg, model=self.model)
+        )
+
+    def predict(self,
+                context: mlflow.pyfunc.PythonModelContext,
+                input_data: pd.DataFrame,
+                params: dict = None) -> pd.DataFrame:
+        """Return one class-token embedding per input image.
+
+        :param input_data: DataFrame with an "image" column holding base64
+            strings, raw bytes or URLs.
+        :param params: optional dict with "device_type" and "to_half_precision".
+        :return: DataFrame with an "image_features" column.
+        """
+        from vision_utils import process_image
+
+        pil_images = [
+            Image.open(io.BytesIO(process_image(image)))
+            for image in input_data[MLflowSchemaLiterals.INPUT_COLUMN_IMAGE]
+        ]
+        batch = torch.stack([self.transforms(img) for img in pil_images])  # size: N x 3 x 224 x 224
+
+        params = params or {}
+        device_type = params.get("device_type", "cuda")
+        to_half_precision = params.get("to_half_precision", False)
+        if device_type == "cuda" and not torch.cuda.is_available():
+            device_type = "cpu"
+        device = torch.device(device_type)
+        self.model.to(device)
+        batch = batch.to(device)
+        # fp16 autocast requires CUDA; CPU autocast only supports bfloat16.
+        amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
+
+        with torch.inference_mode(), torch.autocast(
+            device_type=device_type, dtype=amp_dtype
+        ):
+            output = self.model(batch)  # size: N x 257 x 1280
+
+        class_token = output[:, 0]  # size: N x 1280
+        # The patch tokens (output[:, 1:], size N x 256 x 1280) are unused;
+        # the class token alone serves as the tile embedding.
+        embedding = class_token
+
+        # The model output is fp32 because the final LayerNorm runs in mixed
+        # precision; optionally convert to fp16 for downstream efficiency.
+        if to_half_precision:
+            embedding = embedding.to(torch.float16)
+
+        df_result = pd.DataFrame()
+        df_result[MLflowSchemaLiterals.OUTPUT_COLUMN_IMAGE_FEATURES] = embedding.tolist()
+        return df_result
diff --git a/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/vision_utils.py b/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/vision_utils.py
new file mode 100644
index 0000000000..a0af9f0d17
--- /dev/null
+++ b/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/virchow/vision_utils.py
@@ -0,0 +1,234 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Helper utils for vision Mlflow models."""
+
+import base64
+import io
+import logging
+import os
+import re
+import requests
+import uuid
+
+import PIL
+import pandas as pd
+import numpy as np
+import torch
+
+from ast import literal_eval
+from PIL import Image, UnidentifiedImageError
+from typing import Union
+
+
+logger = logging.getLogger(__name__)
+
+# Uncomment the following line for mlflow debug mode
+# logging.getLogger("mlflow").setLevel(logging.DEBUG)
+
+
+def save_image(output_folder: str, img: PIL.Image.Image, format: str) -> str:
+    """
+    Save image in a folder designated for batch output and return image file path.
+
+    :param output_folder: directory path where we need to save files
+    :type output_folder: str
+    :param img: image object
+    :type img: PIL.Image.Image
+    :param format: format to save image
+    :type format: str
+    :return: file name of image.
+    :rtype: str
+    """
+    filename = f"image_{uuid.uuid4()}.{format.lower()}"
+    img.save(os.path.join(output_folder, filename), format=format)
+    return filename
+
+
+def get_pil_image(image: bytes) -> PIL.Image.Image:
+    """
+    Convert image bytes to PIL image.
+
+    :param image: image bytes
+    :type image: bytes
+    :return: PIL image object
+    :rtype: PIL.Image.Image
+    """
+    try:
+        return Image.open(io.BytesIO(image))
+    except UnidentifiedImageError as e:
+        logger.error("Invalid image format. Please use base64 encoding for input images.")
+        raise e
+
+
+def image_to_base64(img: PIL.Image.Image, format: str) -> str:
+    """
+    Convert image into Base64 encoded string.
+
+    :param img: image object
+    :type img: PIL.Image.Image
+    :param format: image format
+    :type format: str
+    :return: base64 encoded string
+    :rtype: str
+    """
+    buffered = io.BytesIO()
+    img.save(buffered, format=format)
+    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+    return img_str
+
+
+def process_image(image: Union[str, bytes]) -> bytes:
+    """Process image.
+
+    If input image is in bytes format, return it as it is.
+    If input image is in base64 string format, decode it to bytes.
+    If input image is in url format, download it and return bytes.
+    https://github.com/mlflow/mlflow/blob/master/examples/flower_classifier/image_pyfunc.py
+
+    :param image: image in base64 string format or url or bytes.
+    :type image: string or bytes
+    :return: decoded image.
+    :rtype: bytes
+    """
+    if isinstance(image, bytes):
+        return image
+    elif isinstance(image, str):
+        if _is_valid_url(image):
+            try:
+                response = requests.get(image)
+                response.raise_for_status()  # Raise exception in case of unsuccessful response code.
+                image = response.content
+                return image
+            except requests.exceptions.RequestException as ex:
+                raise ValueError(f"Unable to retrieve image from url string due to exception: {ex}")
+        else:
+            try:
+                return base64.b64decode(image)
+            except ValueError:
+                raise ValueError(
+                    "The provided image string cannot be decoded. Expected format is base64 string or url string."
+                )
+    else:
+        raise ValueError(
+            f"Image received in {type(image)} format which is not supported. "
+            "Expected format is bytes, base64 string or url string."
+        )
+
+
+def process_image_pandas_series(image_pandas_series: pd.Series) -> pd.Series:
+    """Process image in Pandas series form.
+
+    If input image is in bytes format, return it as it is.
+    If input image is in base64 string format, decode it to bytes.
+    If input image is in url format, download it and return bytes.
+    https://github.com/mlflow/mlflow/blob/master/examples/flower_classifier/image_pyfunc.py
+
+    :param image_pandas_series: pandas series with image in base64 string format or url or bytes.
+    :type image_pandas_series: pd.Series
+    :return: decoded image in pandas series format.
+    :rtype: Pandas Series
+    """
+    image = image_pandas_series[0]
+    return pd.Series(process_image(image))
+
+
+def _is_valid_url(text: str) -> bool:
+    """Check if text is url or base64 string.
+
+    :param text: text to validate
+    :type text: str
+    :return: True if url else false
+    :rtype: bool
+    """
+    regex = (
+        "((http|https)://)(www.)?"
+        + "[a-zA-Z0-9@:%._\\+~#?&//=\\-]"
+        + "{2,256}\\.[a-z]"
+        + "{2,6}\\b([-a-zA-Z0-9@:%"
+        + "._\\+~#?&//=]*)"
+    )
+    p = re.compile(regex)
+
+    # If the string is empty
+    # return false
+    if text is None:
+        return False
+
+    # Return if the string
+    # matched the ReGex
+    if re.search(p, text):
+        return True
+    else:
+        return False
+
+
+def get_current_device() -> torch.device:
+    """Get current cuda device.
+
+    :return: current device
+    :rtype: torch.device
+    """
+    # check if GPU is available
+    if torch.cuda.is_available():
+        try:
+            # get the current device index
+            device_idx = torch.distributed.get_rank()
+        except RuntimeError as ex:
+            if "Default process group has not been initialized".lower() in str(ex).lower():
+                device_idx = 0
+            else:
+                logger.error(str(ex))
+                raise ex
+        return torch.device(type="cuda", index=device_idx)
+    else:
+        return torch.device(type="cpu")
+
+
+def string_to_nested_float_list(input_str: str) -> list:
+    """Convert string to nested list of floats.
+
+    :return: string converted to nested list of floats
+    :rtype: list
+    """
+    if input_str in ["null", "None", "", "nan", "NoneType", np.nan, None]:
+        return None
+    try:
+        # Use ast.literal_eval to safely evaluate the string into a list
+        nested_list = literal_eval(input_str)
+
+        # Recursive function to convert all numbers in the nested list to floats
+        def to_floats(lst) -> list:
+            """
+            Recursively convert all numbers in a nested list to floats.
+
+            :param lst: nested list
+            :type lst: list
+            :return: nested list of floats
+            :rtype: list
+            """
+            return [to_floats(item) if isinstance(item, list) else float(item) for item in lst]
+
+        # Use the recursive function to process the nested list
+        return to_floats(nested_list)
+    except (ValueError, SyntaxError) as e:
+        # In case of an error during conversion, log a warning and ignore the value
+        logger.warning(f"Invalid input {input_str}: {e}, ignoring.")
+        return None
+
+
+def bool_array_to_pil_image(bool_array: np.ndarray) -> PIL.Image.Image:
+    """Convert boolean array to PIL Image.
+
+    :param bool_array: boolean array
+    :type bool_array: np.array
+    :return: PIL Image
+    :rtype: PIL.Image.Image
+    """
+    # Convert boolean array to uint8
+    uint8_array = bool_array.astype(np.uint8) * 255
+
+    # Create a PIL Image
+    pil_image = Image.fromarray(uint8_array)
+
+    return pil_image
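--

For reviewers: a minimal local smoke test of the converted model, to go with this
patch. This is an illustrative sketch, not part of the change itself; the MLflow
model path ("./mlflow_model_folder") and the sample image file name are
assumptions.

    import base64

    import mlflow.pyfunc
    import pandas as pd

    # Hypothetical output folder written by the convert_model_to_mlflow component.
    model = mlflow.pyfunc.load_model("./mlflow_model_folder")

    # The signature's "image" column accepts base64 strings, raw bytes or URLs
    # (see process_image in vision_utils.py).
    with open("tile_224x224.png", "rb") as f:  # hypothetical 224x224 H&E tile
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    df = pd.DataFrame({"image": [image_b64]})
    result = model.predict(df)

    # One 1280-dimensional class-token embedding per input row.
    print(len(result["image_features"].iloc[0]))

Since the wrapper falls back to CPU when CUDA is unavailable, this check should
also run on a CPU-only build agent.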