-
Notifications
You must be signed in to change notification settings - Fork 149
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Ye, Xinyu <[email protected]>
- Loading branch information
1 parent
179b5da
commit 2587a29
Showing
14 changed files
with
239 additions
and
135 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import base64 | ||
import os | ||
import tempfile | ||
|
||
import torch | ||
from diffusers import DiffusionPipeline | ||
|
||
from comps import CustomLogger, OpeaComponent, OpeaComponentRegistry, SDInputs, SDOutputs, ServiceType | ||
|
||
logger = CustomLogger("opea") | ||
|
||
|
||
@OpeaComponentRegistry.register("OPEA_TEXT2IMAGE") | ||
class OpeaText2image(OpeaComponent): | ||
"""A specialized text2image component derived from OpeaComponent for text2image services. | ||
Attributes: | ||
client (AsyncInferenceClient): An instance of the async client for embedding generation. | ||
model_name (str): The name of the embedding model used. | ||
""" | ||
|
||
def __init__(self, name: str, description: str, config: dict = None): | ||
super().__init__(name, ServiceType.TEXT2IMAGE.name.lower(), description, config) | ||
|
||
# initialize model and tokenizer | ||
self.seed = config["seed"] | ||
model_name_or_path = config["model_name_or_path"] | ||
device = config["device"] | ||
if os.getenv("MODEL", None): | ||
model_name_or_path = os.getenv("MODEL") | ||
kwargs = {} | ||
if config["bf16"]: | ||
kwargs["torch_dtype"] = torch.bfloat16 | ||
if not config["token"]: | ||
config["token"] = os.getenv("HF_TOKEN") | ||
if device == "hpu": | ||
kwargs.update( | ||
{ | ||
"use_habana": True, | ||
"use_hpu_graphs": config["use_hpu_graphs"], | ||
"gaudi_config": "Habana/stable-diffusion", | ||
"token": config["token"], | ||
} | ||
) | ||
if "stable-diffusion-3" in model_name_or_path: | ||
from optimum.habana.diffusers import GaudiStableDiffusion3Pipeline | ||
|
||
self.pipe = GaudiStableDiffusion3Pipeline.from_pretrained( | ||
model_name_or_path, | ||
**kwargs, | ||
) | ||
elif "stable-diffusion" in model_name_or_path.lower() or "flux" in model_name_or_path.lower(): | ||
from optimum.habana.diffusers import AutoPipelineForText2Image | ||
|
||
self.pipe = AutoPipelineForText2Image.from_pretrained( | ||
model_name_or_path, | ||
**kwargs, | ||
) | ||
else: | ||
raise NotImplementedError( | ||
"Only support stable-diffusion, stable-diffusion-xl, stable-diffusion-3 and flux now, " | ||
+ f"model {model_name_or_path} not supported." | ||
) | ||
elif device == "cpu": | ||
self.pipe = DiffusionPipeline.from_pretrained(model_name_or_path, token=config["token"], **kwargs) | ||
else: | ||
raise NotImplementedError(f"Only support cpu and hpu device now, device {device} not supported.") | ||
logger.info("Stable Diffusion model initialized.") | ||
|
||
async def invoke(self, input: SDInputs) -> SDOutputs: | ||
"""Invokes the text2image service to generate image(s) for the provided input. | ||
Args: | ||
input (SDInputs): The input for text2image service, including prompt and optional parameters like num_images_per_prompt. | ||
Returns: | ||
SDOutputs: The response is a list of images. | ||
""" | ||
prompt = input.prompt | ||
num_images_per_prompt = input.num_images_per_prompt | ||
|
||
generator = torch.manual_seed(self.seed) | ||
images = self.pipe(prompt, generator=generator, num_images_per_prompt=num_images_per_prompt).images | ||
with tempfile.TemporaryDirectory() as image_path: | ||
results = [] | ||
for i, image in enumerate(images): | ||
save_path = os.path.join(image_path, f"image_{i+1}.png") | ||
image.save(save_path) | ||
with open(save_path, "rb") as f: | ||
bytes = f.read() | ||
b64_str = base64.b64encode(bytes).decode() | ||
results.append(b64_str) | ||
return SDOutputs(images=results) | ||
|
||
def check_health(self) -> bool: | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import argparse | ||
import os | ||
import time | ||
|
||
from integrations.opea import OpeaText2image | ||
|
||
from comps import ( | ||
CustomLogger, | ||
OpeaComponentLoader, | ||
SDInputs, | ||
SDOutputs, | ||
ServiceType, | ||
opea_microservices, | ||
register_microservice, | ||
register_statistics, | ||
statistics_dict, | ||
) | ||
|
||
logger = CustomLogger("opea_text2image_microservice") | ||
|
||
|
||
@register_microservice( | ||
name="opea_service@text2image", | ||
service_type=ServiceType.TEXT2IMAGE, | ||
endpoint="/v1/text2image", | ||
host="0.0.0.0", | ||
port=9379, | ||
input_datatype=SDInputs, | ||
output_datatype=SDOutputs, | ||
) | ||
@register_statistics(names=["opea_service@text2image"]) | ||
async def text2image(input: SDInputs): | ||
start = time.time() | ||
try: | ||
# Use the loader to invoke the active component | ||
results = await loader.invoke(input) | ||
statistics_dict["opea_service@text2image"].append_latency(time.time() - start, None) | ||
return results | ||
except Exception as e: | ||
logger.error(f"Error during text2image invocation: {e}") | ||
raise | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--model_name_or_path", type=str, default="stabilityai/stable-diffusion-3-medium-diffusers") | ||
parser.add_argument("--use_hpu_graphs", default=False, action="store_true") | ||
parser.add_argument("--device", type=str, default="cpu") | ||
parser.add_argument("--token", type=str, default=None) | ||
parser.add_argument("--seed", type=int, default=42) | ||
parser.add_argument("--bf16", action="store_true") | ||
|
||
args = parser.parse_args() | ||
|
||
text2image_component_name = os.getenv("TEXT2IMAGE_COMPONENT_NAME", "OPEA_TEXT2IMAGE") | ||
# Initialize OpeaComponentLoader | ||
loader = OpeaComponentLoader( | ||
text2image_component_name, | ||
description=f"OPEA TEXT2IMAGE Component: {text2image_component_name}", | ||
config=args.__dict__, | ||
) | ||
|
||
logger.info("Text2image server started.") | ||
opea_microservices["opea_service@text2image"].start() |
File renamed without changes.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.