diff --git a/runner/app/main.py b/runner/app/main.py index 0ba3ee3a..62647487 100644 --- a/runner/app/main.py +++ b/runner/app/main.py @@ -23,7 +23,9 @@ async def lifespan(app: FastAPI): pipeline = os.environ["PIPELINE"] model_id = os.environ["MODEL_ID"] - app.pipeline = load_pipeline(pipeline, model_id) + task = os.environ["TASK"] if pipeline == "image-to-image-generic" else None + + app.pipeline = load_pipeline(pipeline, model_id, task) app.include_router(load_route(pipeline)) app.hardware_info_service.log_gpu_compute_info() @@ -34,7 +36,7 @@ async def lifespan(app: FastAPI): logger.info("Shutting down") -def load_pipeline(pipeline: str, model_id: str) -> any: +def load_pipeline(pipeline: str, model_id: str, task: str) -> any: match pipeline: case "text-to-image": from app.pipelines.text_to_image import TextToImagePipeline @@ -78,6 +80,10 @@ def load_pipeline(pipeline: str, model_id: str) -> any: from app.pipelines.text_to_speech import TextToSpeechPipeline return TextToSpeechPipeline(model_id) + case "image-to-image-generic": + from app.pipelines.image_to_image_generic import ImageToImageGenericPipeline + + return ImageToImageGenericPipeline(model_id, task) case _: raise EnvironmentError( f"{pipeline} is not a valid pipeline for model {model_id}" @@ -128,6 +134,10 @@ def load_route(pipeline: str) -> any: from app.routes import text_to_speech return text_to_speech.router + case "image-to-image-generic": + from app.routes import image_to_image_generic + + return image_to_image_generic.router case _: raise EnvironmentError(f"{pipeline} is not a valid pipeline") diff --git a/runner/app/pipelines/image_to_image_generic.py b/runner/app/pipelines/image_to_image_generic.py new file mode 100644 index 00000000..681ba2c8 --- /dev/null +++ b/runner/app/pipelines/image_to_image_generic.py @@ -0,0 +1,301 @@ +import json +import logging +import numpy as np +import os +from enum import Enum +from typing import List, Optional, Tuple + +import PIL +import torch +from diffusers import ( + AutoPipelineForInpainting, + ControlNetModel, + StableDiffusionXLControlNetPipeline, + StableDiffusionXLInpaintPipeline, + EulerAncestralDiscreteScheduler, + AutoencoderKL, +) +from huggingface_hub import file_download +from PIL import Image, ImageOps + +from app.pipelines.base import Pipeline +from app.pipelines.utils import ( + LoraLoader, + get_model_dir, + get_torch_device, +) +from app.utils.errors import InferenceError + +logger = logging.getLogger(__name__) + + +class TaskType(Enum): + """Enumeration for task types.""" + + INPAINTING = "inpainting" + OUTPAINTING = "outpainting" + SKETCH_TO_IMAGE = "sketch_to_image" + + @classmethod + def list(cls): + return [task.value for task in cls] + + +class ImageToImageGenericPipeline(Pipeline): + def __init__(self, model_id: str, task: str): + kwargs = {"cache_dir": get_model_dir(), "torch_dtype": torch.float16} + torch_device = get_torch_device() + + # Check if the model_id is a dictionary in string format in the default value case of model_id on go livepeer side + if model_id.startswith("{") and model_id.endswith("}"): + try: + # Perform json parsing of the string into a dictionary + model_id_dict = json.loads(model_id.replace("'", '"')) # Replace single quotes to make it JSON compliant + if isinstance(model_id_dict, dict) and task in model_id_dict: + model_id = model_id_dict[task] + except json.JSONDecodeError: + pass + + folder_name = file_download.repo_folder_name( + repo_id=model_id, repo_type="model" + ) + folder_path = os.path.join(get_model_dir(), folder_name) + # Load the fp16 variant if fp16 'safetensors' files are present in the cache. + # NOTE: Exception for SDXL-Lightning model: despite having fp16 'safetensors' + # files, they are not named according to the standard convention. + has_fp16_variant = any( + ".fp16.safetensors" in fname + for _, _, files in os.walk(folder_path) + for fname in files + ) + if torch_device.type != "cpu" and has_fp16_variant: + logger.info( + "ImageToImageGenericPipeline loading fp16 variant for %s", model_id + ) + + kwargs["torch_dtype"] = torch.float16 + kwargs["variant"] = "fp16" + + if task not in TaskType.list(): + raise ValueError(f"Unsupported task: {task}") + + self.task = task + + # Initialize pipelines based on task + if self.task == TaskType.INPAINTING.value: + self.pipeline = AutoPipelineForInpainting.from_pretrained( + model_id, **kwargs + ) + self.pipeline.enable_model_cpu_offload() + + elif self.task == TaskType.OUTPAINTING.value: + self.controlnet = ( + ControlNetModel.from_pretrained( + model_id, torch_dtype=torch.float16, variant="fp16" + ), + ) + self.vae = AutoencoderKL.from_pretrained( + "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16 + ) + self.pipeline_stage1 = StableDiffusionXLControlNetPipeline.from_pretrained( + "SG161222/RealVisXL_V4.0", + controlnet=self.controlnet, + vae=self.vae, + safety_checker=None, + **kwargs, + ) + self.pipeline_stage1.enable_model_cpu_offload() + self.pipeline_stage2 = StableDiffusionXLInpaintPipeline.from_pretrained( + "OzzyGT/RealVisXL_V4.0_inpainting", vae=self.vae, **kwargs + ) + self.pipeline_stage1.enable_model_cpu_offload() + + elif self.task == TaskType.SKETCH_TO_IMAGE.value: + self.controlnet = ControlNetModel.from_pretrained(model_id, **kwargs) + self.vae = AutoencoderKL.from_pretrained( + "madebyollin/sdxl-vae-fp16-fix", **kwargs + ) + eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler" + ) + self.pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + controlnet=self.controlnet, + vae=self.vae, + safety_checker=None, + scheduler=eulera_scheduler, + **kwargs, + ) + self.pipeline.enable_model_cpu_offload() + + self._lora_loader = LoraLoader(self.pipeline) + + if self.task == TaskType.OUTPAINTING.value: + self._lora_loader1 = LoraLoader(self.pipeline_stage1) + self._lora_loader2 = LoraLoader(self.pipeline_stage2) + + def __call__( + self, + prompt: List[str], + image: PIL.Image.Image, + mask_image: Optional[PIL.Image.Image] = None, + **kwargs, + ) -> Tuple[List[PIL.Image], List[Optional[bool]]]: + # Handle num_inference_steps and other model-specific settings + if "num_inference_steps" in kwargs and ( + kwargs["num_inference_steps"] is None or any(x < 1 for x in kwargs["num_inference_steps"]) + ): + logger.warning("Invalid num_inference_steps found. Deleting it from kwargs.") + del kwargs["num_inference_steps"] + + # Extract parameters from kwargs + seed = kwargs.pop("seed", None) + safety_check = kwargs.pop("safety_check", True) + loras_json = kwargs.pop("loras", "") + guidance_scale = kwargs.pop("guidance_scale", None) + num_inference_steps = kwargs.pop("num_inference_steps", None) + controlnet_conditioning_scale = kwargs.pop( + "controlnet_conditioning_scale", None + ) + control_guidance_end = kwargs.pop("control_guidance_end", None) + strength = kwargs.pop("strength", None) + + if len(prompt) == 1: + prompt = prompt[0] + + # Handle seed initialization for reproducibility + if seed is not None: + if isinstance(seed, int): + kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed( + seed + ) + elif isinstance(seed, list): + kwargs["generator"] = [ + torch.Generator(get_torch_device()).manual_seed(s) for s in seed + ] + + # Dynamically (un)load LoRas. + if not loras_json: + if self.task == TaskType.OUTPAINTING.value: + self._lora_loader1.disable_loras() + self._lora_loader2.disable_loras() + else: + self._lora_loader.disable_loras() # Assuming _lora_loader is defined elsewhere + else: + if self.task == TaskType.OUTPAINTING.value: + self._lora_loader1.load_loras(loras_json) + self._lora_loader2.load_loras(loras_json) + else: + self._lora_loader.load_loras( + loras_json + ) # Assuming _lora_loader is defined elsewhere + + # Ensure proper inference configuration based on model + if self.task == TaskType.INPAINTING.value: + if mask_image is None: + raise ValueError("Mask image is required for inpainting.") + try: + outputs = self.pipeline( + prompt=prompt, + image=image, + mask_image=mask_image, + guidance_scale=guidance_scale[0], + strength=strength, + **kwargs, + ).images[0] + except torch.cuda.OutOfMemoryError as e: + raise e + except Exception as e: + raise InferenceError(original_exception=e) + elif self.task == TaskType.OUTPAINTING.value: + try: + resized_image, white_bg_image = self._scale_and_paste(image) + temp_image = self.pipeline_stage1( + prompt=prompt[0], + image=white_bg_image, + guidance_scale=guidance_scale[0], + num_inference_steps=num_inference_steps[0], + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_end=control_guidance_end, + **kwargs, + ).images[0] + + x = (1024 - resized_image.width) // 2 + y = (1024 - resized_image.height) // 2 + temp_image.paste(resized_image, (x, y), resized_image) + + mask = Image.new("L", temp_image.size) + mask.paste(resized_image.split()[3], (x, y)) + mask = ImageOps.invert(mask) + final_mask = mask.point(lambda p: p > 128 and 255) + mask_blurred = self.pipeline_stage2.mask_processor.blur( + final_mask, blur_factor=20 + ) + + outputs = self.pipeline_stage2( + prompt[1], + image=temp_image, + mask_image=mask_blurred, + strength=strength, + guidance_scale=guidance_scale[1], + num_inference_steps=num_inference_steps[1], + **kwargs, + ).images[0] + + x = (1024 - resized_image.width) // 2 + y = (1024 - resized_image.height) // 2 + outputs.paste(resized_image, (x, y), resized_image) + except torch.cuda.OutOfMemoryError as e: + raise e + except Exception as e: + raise InferenceError(original_exception=e) + elif self.task == TaskType.SKETCH_TO_IMAGE.value: + try: + # must resize to 1024*1024 or same resolution bucket to get the best performance + width, height = image.size + ratio = np.sqrt(1024.0 * 1024.0 / (width * height)) + new_width, new_height = int(width * ratio), int(height * ratio) + image = image.resize((new_width, new_height)) + outputs = self.pipeline( + prompt=prompt, + image=image, + num_inference_steps=num_inference_steps[0], + controlnet_conditioning_scale=controlnet_conditioning_scale, + **kwargs, + ).images[0] + except torch.cuda.OutOfMemoryError as e: + raise e + except Exception as e: + raise InferenceError(original_exception=e) + + # Safety check for NSFW content + if safety_check: + _, has_nsfw_concept = self._safety_checker.check_nsfw_images(outputs.images) + else: + has_nsfw_concept = [None] * len(outputs.images) + + return outputs, has_nsfw_concept # Return the first image in the output list + + @staticmethod + def _scale_and_paste( + original_image: PIL.Image.Image, + ) -> Tuple[PIL.Image.Image, PIL.Image.Image]: + """Resize and paste the original image onto a 1024x1024 white canvas.""" + aspect_ratio = original_image.width / original_image.height + if original_image.width > original_image.height: + new_width = 1024 + new_height = round(new_width / aspect_ratio) + else: + new_height = 1024 + new_width = round(new_height * aspect_ratio) + + resized_original = original_image.resize((new_width, new_height), Image.LANCZOS) + white_background = Image.new("RGBA", (1024, 1024), "white") + x = (1024 - new_width) // 2 + y = (1024 - new_height) // 2 + white_background.paste(resized_original, (x, y), resized_original) + + return resized_original, white_background + + def __str__(self) -> str: + return f"ImageToImageGenericPipeline task={self.task}" diff --git a/runner/app/routes/image_to_image_generic.py b/runner/app/routes/image_to_image_generic.py new file mode 100644 index 00000000..7d77f7bd --- /dev/null +++ b/runner/app/routes/image_to_image_generic.py @@ -0,0 +1,258 @@ +import base64 +import logging +import numpy as np +import os +import random +import zlib +from typing import Annotated, Dict, Tuple, Union + +import torch +from fastapi import APIRouter, Depends, File, Form, UploadFile, status +from fastapi.responses import JSONResponse +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from PIL import Image, ImageFile + +from app.dependencies import get_pipeline +from app.pipelines.base import Pipeline +from app.routes.utils import ( + HTTPError, + ImageResponse, + handle_pipeline_exception, + http_error, + image_to_data_url, + json_str_to_np_array, +) + +ImageFile.LOAD_TRUNCATED_IMAGES = True + +router = APIRouter() + +logger = logging.getLogger(__name__) + + +# Pipeline specific error handling configuration. +PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = { + # Specific error types. + "OutOfMemoryError": ( + "Out of memory error. Try reducing input image resolution.", + status.HTTP_500_INTERNAL_SERVER_ERROR, + ) +} + +RESPONSES = { + status.HTTP_200_OK: { + "content": { + "application/json": { + "schema": { + "x-speakeasy-name-override": "data", + } + } + }, + }, + status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, + status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, +} + + +# TODO: Make model_id and other None properties optional once Go codegen tool supports +# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 +@router.post( + "/image-to-image-generic", + response_model=ImageResponse, + responses=RESPONSES, + description="Apply image transformations to a provided image according to the choice of tasks, i.e., outpainting, inpainting, sketch2image.", + operation_id="genImageToImageGeneric", + summary="Image To Image Generic", + tags=["generate"], + openapi_extra={"x-speakeasy-name-override": "imageToImageGeneric"}, +) +@router.post( + "/image-to-image-generic/", + response_model=ImageResponse, + responses=RESPONSES, + include_in_schema=False, +) +async def image_to_image_generic( + prompt: Annotated[ + str, + Form(description="Text prompt(s) to guide image generation."), + ], + image: Annotated[ + UploadFile, + File(description="Uploaded image to modify with the pipeline."), + ], + mask_image: Annotated[ + str, + Form( + description=( + "Mask image to determine which regions of an image to fill in" + "for inpainting task with the form HxW." + ) + ), + ] = None, + model_id: Annotated[ + str, + Form(description="Hugging Face model ID used for image generation."), + ] = "", + loras: Annotated[ + str, + Form( + description=( + "A LoRA (Low-Rank Adaptation) model and its corresponding weight for " + 'image generation. Example: { "latent-consistency/lcm-lora-sdxl": ' + '1.0, "nerijs/pixel-art-xl": 1.2}.' + ) + ), + ] = "", + strength: Annotated[ + float, + Form( + description=( + "Degree of transformation applied to the reference image (0 to 1)." + ) + ), + ] = 0.8, + guidance_scale: Annotated[ + str, + Form( + description=( + "Encourages model to generate images closely linked to the text prompt " + "(higher values may reduce image quality)." + ) + ), + ] = "[6.5, 10.0]", + negative_prompt: Annotated[ + str, + Form( + description=( + "Text prompt(s) to guide what to exclude from image generation. " + "Ignored if guidance_scale < 1." + ) + ), + ] = "", + safety_check: Annotated[ + bool, + Form( + description=( + "Perform a safety check to estimate if generated images could be " + "offensive or harmful." + ) + ), + ] = True, + seed: Annotated[int, Form(description="Seed for random number generation.")] = None, + num_inference_steps: Annotated[ + str, + Form( + description=( + "Number of denoising steps. More steps usually lead to higher quality " + "images but slower inference. Modulated by strength." + ) + ), + ] = "[30, 25]", + controlnet_conditioning_scale: Annotated[ + float, + Form( + description=( + "Determines how much weight to assign to the conditioning inputs." + ) + ), + ] = 0.5, + control_guidance_end: Annotated[ + float, + Form( + description=( + "The percentage of total steps at which the ControlNet stops applying." + ) + ), + ] = 0.9, + num_images_per_prompt: Annotated[ + int, + Form(description="Number of images to generate per prompt."), + ] = 1, + pipeline: Pipeline = Depends(get_pipeline), + token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), +): + auth_token = os.environ.get("AUTH_TOKEN") + if auth_token: + if not token or token.credentials != auth_token: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + headers={"WWW-Authenticate": "Bearer"}, + content=http_error("Invalid bearer token."), + ) + + if model_id != "" and model_id != pipeline.model_id: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error( + f"pipeline configured with {pipeline.model_id} but called with " + f"{model_id}." + ), + ) + + seed = seed if seed is not None else random.randint(0, 2**32 - 1) + seeds = [seed + i for i in range(num_images_per_prompt)] + + image = Image.open(image.file).convert("RGB") + + try: + prompt = json_str_to_np_array(prompt, var_name="prompt") + guidance_scale = json_str_to_np_array(guidance_scale, var_name="guidance_scale") + num_inference_steps = json_str_to_np_array( + num_inference_steps, var_name="num_inference_steps" + ) + if mask_image: + mask_image = base64.b64decode(mask_image) + mask_image = zlib.decompress(mask_image) + mask_image = np.frombuffer(mask_image, dtype=np.uint8) + mask_image = mask_image.reshape((image.size)) + mask_image = Image.fromarray(mask_image, mode="L") + except ValueError as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error(str(e)), + ) + + # TODO: Process one image at a time to avoid CUDA OEM errors. Can be removed again + # once LIV-243 and LIV-379 are resolved. + images = [] + has_nsfw_concept = [] + for seed in seeds: + try: + imgs, nsfw_checks = pipeline( + prompt=prompt, + image=image, + mask_image=mask_image, + strength=strength, + loras=loras, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + safety_check=safety_check, + seed=seed, + num_images_per_prompt=1, + num_inference_steps=num_inference_steps, + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_end=control_guidance_end, + ) + except Exception as e: + if isinstance(e, torch.cuda.OutOfMemoryError): + # TODO: Investigate why not all VRAM memory is cleared. + torch.cuda.empty_cache() + logger.error(f"ImageToImageGenericPipeline pipeline error: {e}") + return handle_pipeline_exception( + e, + default_error_message="Image-to-image-generic pipeline error.", + custom_error_config=PIPELINE_ERROR_CONFIG, + ) + images.extend(imgs) + has_nsfw_concept.extend(nsfw_checks) + + # TODO: Return None once Go codegen tool supports optional properties + # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 + output_images = [ + {"url": image_to_data_url(img), "seed": sd, "nsfw": nsfw or False} + for img, sd, nsfw in zip(images, seeds, has_nsfw_concept) + ] + + return {"images": output_images} diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index 6ebb88ef..8b4c6ece 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -80,6 +80,13 @@ function download_all_models() { # Custom pipeline models. huggingface-cli download facebook/sam2-hiera-large --include "*.pt" "*.yaml" --cache-dir models + # Download image-to-image-generic models. + huggingface-cli download madebyollin/sdxl-vae-fp16-fix --include "*.safetensors" "*.json" --cache-dir models + huggingface-cli download OzzyGT/RealVisXL_V4.0_inpainting --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models + huggingface-cli download kandinsky-community/kandinsky-2-2-decoder-inpaint --include "*.safetensors" "*.json" --cache-dir models + huggingface-cli download destitech/controlnet-inpaint-dreamer-sdxl --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models + huggingface-cli download xinsir/controlnet-scribble-sdxl-1.0 --include "*.safetensors" "*.json" --cache-dir models + download_live_models } diff --git a/runner/gateway.openapi.yaml b/runner/gateway.openapi.yaml index 4da1dda7..dffa39d1 100644 --- a/runner/gateway.openapi.yaml +++ b/runner/gateway.openapi.yaml @@ -509,6 +509,54 @@ paths: security: - HTTPBearer: [] x-speakeasy-name-override: textToSpeech + /image-to-image-generic: + post: + tags: + - generate + summary: Image To Image Generic + description: Apply image transformations to a provided image according to the choice of tasks, i.e., outpainting, inpainting, sketch2image. + operationId: genImageToImageGeneric + requestBody: + content: + multipart/form-data: + schema: + $ref: '#/components/schemas/Body_genImageToImageGeneric' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/ImageResponse' + x-speakeasy-name-override: data + '400': + description: Bad Request + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '500': + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + security: + - HTTPBearer: [] + x-speakeasy-name-override: imageToImageGeneric components: schemas: APIError: @@ -807,6 +855,89 @@ components: - image - model_id title: Body_genUpscale + Body_genImageToImageGeneric: + properties: + prompt: + type: string + title: Prompt + description: Text prompt(s) to guide image generation. + image: + type: string + format: binary + title: Image + description: Uploaded image to modify with the pipeline. + mask_image: + type: string + title: Mask Image + description: Mask image to determine which regions of an image to fill in for + inpainting task with the form HxW. + model_id: + type: string + title: Model Id + description: Hugging Face model ID used for image generation. + default: '' + loras: + type: string + title: Loras + description: 'A LoRA (Low-Rank Adaptation) model and its corresponding weight + for image generation. Example: { "latent-consistency/lcm-lora-sdxl": 1.0, + "nerijs/pixel-art-xl": 1.2}.' + default: '' + strength: + type: number + title: Strength + description: Degree of transformation applied to the reference image (0 + to 1). + default: 0.8 + guidance_scale: + type: string + title: Guidance Scale + description: Encourages model to generate images closely linked to the text + prompt (higher values may reduce image quality). + default: '[6.5, 10.0]' + negative_prompt: + type: string + title: Negative Prompt + description: Text prompt(s) to guide what to exclude from image generation. + Ignored if guidance_scale < 1. + default: '' + safety_check: + type: boolean + title: Safety Check + description: Perform a safety check to estimate if generated images could + be offensive or harmful. + default: true + seed: + type: integer + title: Seed + description: Seed for random number generation. + num_inference_steps: + type: string + title: Num Inference Steps + description: Number of denoising steps. More steps usually lead to higher + quality images but slower inference. Modulated by strength. + default: '[30, 25]' + controlnet_conditioning_scale: + type: number + title: Controlnet Conditioning Scale + description: Determines how much weight to assign to the conditioning inputs. + default: 0.5 + control_guidance_end: + type: number + title: Control Guidance End + description: The percentage of total steps at which the ControlNet stops applying. + default: 0.9 + num_images_per_prompt: + type: integer + title: Num Images Per Prompt + description: Number of images to generate per prompt. + default: 1 + type: object + required: + - prompt + - image + - model_id + title: Body_genImageToImageGeneric Chunk: properties: timestamp: diff --git a/runner/gen_openapi.py b/runner/gen_openapi.py index 94072e04..af1ccf7d 100644 --- a/runner/gen_openapi.py +++ b/runner/gen_openapi.py @@ -20,6 +20,7 @@ text_to_image, text_to_speech, upscale, + image_to_image_generic, ) logging.basicConfig( @@ -114,6 +115,7 @@ def write_openapi(fname: str, entrypoint: str = "runner"): app.include_router(image_to_text.router) app.include_router(live_video_to_video.router) app.include_router(text_to_speech.router) + app.include_router(image_to_image_generic.router) logger.info(f"Generating OpenAPI schema for '{entrypoint}' entrypoint...") openapi = get_openapi( @@ -164,8 +166,8 @@ def write_openapi(fname: str, entrypoint: str = "runner"): parser.add_argument( "--entrypoint", type=str, - choices=["gateway","runner"], - default=["gateway","runner"], + choices=["gateway", "runner"], + default=["gateway", "runner"], nargs="+", help=( "The entrypoint to generate the OpenAPI schema for, options are 'runner' " diff --git a/runner/openapi.yaml b/runner/openapi.yaml index 469f1b8a..89a67987 100644 --- a/runner/openapi.yaml +++ b/runner/openapi.yaml @@ -509,6 +509,54 @@ paths: security: - HTTPBearer: [] x-speakeasy-name-override: textToSpeech + /image-to-image-generic: + post: + tags: + - generate + summary: Image To Image Generic + description: Apply image transformations to a provided image according to the choice of tasks, i.e., outpainting, inpainting, sketch2image. + operationId: genImageToImageGeneric + requestBody: + content: + multipart/form-data: + schema: + $ref: '#/components/schemas/Body_genImageToImageGeneric' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/ImageResponse' + x-speakeasy-name-override: data + '400': + description: Bad Request + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '500': + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + security: + - HTTPBearer: [] + x-speakeasy-name-override: imageToImageGeneric /health: get: summary: Health @@ -839,6 +887,89 @@ components: - prompt - image title: Body_genUpscale + Body_genImageToImageGeneric: + properties: + prompt: + type: string + title: Prompt + description: Text prompt(s) to guide image generation. + image: + type: string + format: binary + title: Image + description: Uploaded image to modify with the pipeline. + mask_image: + type: string + title: Mask Image + description: Mask image to determine which regions of an image to fill in for + inpainting task with the form HxW. + model_id: + type: string + title: Model Id + description: Hugging Face model ID used for image generation. + default: '' + loras: + type: string + title: Loras + description: 'A LoRA (Low-Rank Adaptation) model and its corresponding weight + for image generation. Example: { "latent-consistency/lcm-lora-sdxl": 1.0, + "nerijs/pixel-art-xl": 1.2}.' + default: '' + strength: + type: number + title: Strength + description: Degree of transformation applied to the reference image (0 + to 1). + default: 0.8 + guidance_scale: + type: string + title: Guidance Scale + description: Encourages model to generate images closely linked to the text + prompt (higher values may reduce image quality). + default: '[6.5, 10.0]' + negative_prompt: + type: string + title: Negative Prompt + description: Text prompt(s) to guide what to exclude from image generation. + Ignored if guidance_scale < 1. + default: '' + safety_check: + type: boolean + title: Safety Check + description: Perform a safety check to estimate if generated images could + be offensive or harmful. + default: true + seed: + type: integer + title: Seed + description: Seed for random number generation. + num_inference_steps: + type: string + title: Num Inference Steps + description: Number of denoising steps. More steps usually lead to higher + quality images but slower inference. Modulated by strength. + default: '[30, 25]' + controlnet_conditioning_scale: + type: number + title: Controlnet Conditioning Scale + description: Determines how much weight to assign to the conditioning inputs. + default: 0.5 + control_guidance_end: + type: number + title: Control Guidance End + description: The percentage of total steps at which the ControlNet stops applying. + default: 0.9 + num_images_per_prompt: + type: integer + title: Num Images Per Prompt + description: Number of images to generate per prompt. + default: 1 + type: object + required: + - prompt + - image + - model_id + title: Body_genImageToImageGeneric Chunk: properties: timestamp: diff --git a/worker/docker.go b/worker/docker.go index 1b30f350..6ece9586 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -44,16 +44,17 @@ var maxHealthCheckFailures = 2 // This only works right now on a single GPU because if there is another container // using the GPU we stop it so we don't have to worry about having enough ports var containerHostPorts = map[string]string{ - "text-to-image": "8000", - "image-to-image": "8100", - "image-to-video": "8200", - "upscale": "8300", - "audio-to-text": "8400", - "llm": "8500", - "segment-anything-2": "8600", - "image-to-text": "8700", - "text-to-speech": "8800", - "live-video-to-video": "8900", + "text-to-image": "8000", + "image-to-image": "8100", + "image-to-video": "8200", + "upscale": "8300", + "audio-to-text": "8400", + "llm": "8500", + "segment-anything-2": "8600", + "image-to-text": "8700", + "text-to-speech": "8800", + "live-video-to-video": "8900", + "image-to-image-generic": "9000", } // Mapping for per pipeline container images. diff --git a/worker/multipart.go b/worker/multipart.go index bc70ba8f..565fe589 100644 --- a/worker/multipart.go +++ b/worker/multipart.go @@ -363,3 +363,93 @@ func NewImageToTextMultipartWriter(w io.Writer, req GenImageToTextMultipartReque return mw, nil } + +func NewImageToImageGenericMultipartWriter(w io.Writer, req GenImageToImageGenericMultipartRequestBody) (*multipart.Writer, error) { + mw := multipart.NewWriter(w) + writer, err := mw.CreateFormFile("image", req.Image.Filename()) + if err != nil { + return nil, err + } + imageSize := req.Image.FileSize() + imageRdr, err := req.Image.Reader() + if err != nil { + return nil, err + } + copied, err := io.Copy(writer, imageRdr) + if err != nil { + return nil, err + } + if copied != imageSize { + return nil, fmt.Errorf("failed to copy image to multipart request imageBytes=%v copiedBytes=%v", imageSize, copied) + } + + if err := mw.WriteField("prompt", req.Prompt); err != nil { + return nil, err + } + if req.ModelId != nil { + if err := mw.WriteField("model_id", *req.ModelId); err != nil { + return nil, err + } + } + if req.MaskImage != nil { + if err := mw.WriteField("mask_image", *req.MaskImage); err != nil { + return nil, err + } + } + if req.Loras != nil { + if err := mw.WriteField("loras", *req.Loras); err != nil { + return nil, err + } + } + if req.Strength != nil { + if err := mw.WriteField("strength", fmt.Sprintf("%f", *req.Strength)); err != nil { + return nil, err + } + } + if req.GuidanceScale != nil { + if err := mw.WriteField("guidance_scale", *req.GuidanceScale); err != nil { + return nil, err + } + } + if req.NegativePrompt != nil { + if err := mw.WriteField("negative_prompt", *req.NegativePrompt); err != nil { + return nil, err + } + } + if req.SafetyCheck != nil { + if err := mw.WriteField("safety_check", strconv.FormatBool(*req.SafetyCheck)); err != nil { + return nil, err + } + } + if req.Seed != nil { + if err := mw.WriteField("seed", strconv.Itoa(*req.Seed)); err != nil { + return nil, err + } + } + if req.NumImagesPerPrompt != nil { + if err := mw.WriteField("num_images_per_prompt", strconv.Itoa(*req.NumImagesPerPrompt)); err != nil { + return nil, err + } + } + if req.NumInferenceSteps != nil { + if err := mw.WriteField("num_inference_steps", *req.NumInferenceSteps); err != nil { + return nil, err + } + } + if req.ControlnetConditioningScale != nil { + if err := mw.WriteField("controlnet_conditioning_scale", fmt.Sprintf("%f", *req.ControlnetConditioningScale)); err != nil { + return nil, err + } + } + if req.ControlGuidanceEnd != nil { + if err := mw.WriteField("control_guidance_end", fmt.Sprintf("%f", *req.ControlGuidanceEnd)); err != nil { + return nil, err + } + } + + if err := mw.Close(); err != nil { + return nil, err + } + + return mw, nil +} diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 587437ed..bb8ec9ee 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -99,6 +99,51 @@ type BodyGenImageToImage struct { Strength *float32 `json:"strength,omitempty"` } +// BodyGenImageToImageGeneric defines model for Body_genImageToImageGeneric. +type BodyGenImageToImageGeneric struct { + // ControlGuidanceEnd The percentage of total steps at which the ControlNet stops applying. + ControlGuidanceEnd *float32 `json:"control_guidance_end,omitempty"` + + // ControlnetConditioningScale Determines how much weight to assign to the conditioning inputs. + ControlnetConditioningScale *float32 `json:"controlnet_conditioning_scale,omitempty"` + + // GuidanceScale Encourages model to generate images closely linked to the text prompt (higher values may reduce image quality). + GuidanceScale *string `json:"guidance_scale,omitempty"` + + // Image Uploaded image to modify with the pipeline. + Image openapi_types.File `json:"image"` + + // Loras A LoRA (Low-Rank Adaptation) model and its corresponding weight for image generation. Example: { "latent-consistency/lcm-lora-sdxl": 1.0, "nerijs/pixel-art-xl": 1.2}. + Loras *string `json:"loras,omitempty"` + + // MaskImage Mask image to determine which regions of an image to fill in for inpainting task with the form HxW. + MaskImage *string `json:"mask_image,omitempty"` + + // ModelId Hugging Face model ID used for image generation. + ModelId string `json:"model_id"` + + // NegativePrompt Text prompt(s) to guide what to exclude from image generation. Ignored if guidance_scale < 1. + NegativePrompt *string `json:"negative_prompt,omitempty"` + + // NumImagesPerPrompt Number of images to generate per prompt. + NumImagesPerPrompt *int `json:"num_images_per_prompt,omitempty"` + + // NumInferenceSteps Number of denoising steps. More steps usually lead to higher quality images but slower inference. Modulated by strength. + NumInferenceSteps *string `json:"num_inference_steps,omitempty"` + + // Prompt Text prompt(s) to guide image generation. + Prompt string `json:"prompt"` + + // SafetyCheck Perform a safety check to estimate if generated images could be offensive or harmful. + SafetyCheck *bool `json:"safety_check,omitempty"` + + // Seed Seed for random number generation. + Seed *int `json:"seed,omitempty"` + + // Strength Degree of transformation applied to the reference image (0 to 1). + Strength *float32 `json:"strength,omitempty"` +} + // BodyGenImageToText defines model for Body_genImageToText. type BodyGenImageToText struct { // Image Uploaded image to transform with the pipeline. @@ -449,6 +494,9 @@ type GenAudioToTextMultipartRequestBody = BodyGenAudioToText // GenImageToImageMultipartRequestBody defines body for GenImageToImage for multipart/form-data ContentType. type GenImageToImageMultipartRequestBody = BodyGenImageToImage +// GenImageToImageGenericMultipartRequestBody defines body for GenImageToImageGeneric for multipart/form-data ContentType. +type GenImageToImageGenericMultipartRequestBody = BodyGenImageToImageGeneric + // GenImageToTextMultipartRequestBody defines body for GenImageToText for multipart/form-data ContentType. type GenImageToTextMultipartRequestBody = BodyGenImageToText @@ -623,6 +671,9 @@ type ClientInterface interface { // GenImageToImageWithBody request with any body GenImageToImageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + // GenImageToImageGenericWithBody request with any body + GenImageToImageGenericWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + // GenImageToTextWithBody request with any body GenImageToTextWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) @@ -716,6 +767,18 @@ func (c *Client) GenImageToImageWithBody(ctx context.Context, contentType string return c.Client.Do(req) } +func (c *Client) GenImageToImageGenericWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGenImageToImageGenericRequestWithBody(c.Server, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + func (c *Client) GenImageToTextWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { req, err := NewGenImageToTextRequestWithBody(c.Server, contentType, body) if err != nil { @@ -999,6 +1062,35 @@ func NewGenImageToImageRequestWithBody(server string, contentType string, body i return req, nil } +// NewGenImageToImageGenericRequestWithBody generates requests for GenImageToImageGeneric with any type of body +func NewGenImageToImageGenericRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/image-to-image-generic") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + // NewGenImageToTextRequestWithBody generates requests for GenImageToText with any type of body func NewGenImageToTextRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { var err error @@ -1333,6 +1425,9 @@ type ClientWithResponsesInterface interface { // GenImageToImageWithBodyWithResponse request with any body GenImageToImageWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenImageToImageResponse, error) + // GenImageToImageGenericWithBodyWithResponse request with any body + GenImageToImageGenericWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenImageToImageGenericResponse, error) + // GenImageToTextWithBodyWithResponse request with any body GenImageToTextWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenImageToTextResponse, error) @@ -1486,6 +1581,32 @@ func (r GenImageToImageResponse) StatusCode() int { return 0 } +type GenImageToImageGenericResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *ImageResponse + JSON400 *HTTPError + JSON401 *HTTPError + JSON422 *HTTPValidationError + JSON500 *HTTPError +} + +// Status returns HTTPResponse.Status +func (r GenImageToImageGenericResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r GenImageToImageGenericResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + type GenImageToTextResponse struct { Body []byte HTTPResponse *http.Response @@ -1740,6 +1861,15 @@ func (c *ClientWithResponses) GenImageToImageWithBodyWithResponse(ctx context.Co return ParseGenImageToImageResponse(rsp) } +// GenImageToImageGenericWithBodyWithResponse request with arbitrary body returning *GenImageToImageGenericResponse +func (c *ClientWithResponses) GenImageToImageGenericWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenImageToImageGenericResponse, error) { + rsp, err := c.GenImageToImageGenericWithBody(ctx, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParseGenImageToImageGenericResponse(rsp) +} + // GenImageToTextWithBodyWithResponse request with arbitrary body returning *GenImageToTextResponse func (c *ClientWithResponses) GenImageToTextWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenImageToTextResponse, error) { rsp, err := c.GenImageToTextWithBody(ctx, contentType, body, reqEditors...) @@ -2044,6 +2174,60 @@ func ParseGenImageToImageResponse(rsp *http.Response) (*GenImageToImageResponse, return response, nil } +// ParseGenImageToImageGenericResponse parses an HTTP response from a GenImageToImageGenericWithResponse call +func ParseGenImageToImageGenericResponse(rsp *http.Response) (*GenImageToImageGenericResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &GenImageToImageGenericResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest ImageResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 422: + var dest HTTPValidationError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON422 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 500: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON500 = &dest + + } + + return response, nil +} + // ParseGenImageToTextResponse parses an HTTP response from a GenImageToTextWithResponse call func ParseGenImageToTextResponse(rsp *http.Response) (*GenImageToTextResponse, error) { bodyBytes, err := io.ReadAll(rsp.Body) @@ -2500,6 +2684,9 @@ type ServerInterface interface { // Image To Image // (POST /image-to-image) GenImageToImage(w http.ResponseWriter, r *http.Request) + // Image To Image Generic + // (POST /image-to-image-generic) + GenImageToImageGeneric(w http.ResponseWriter, r *http.Request) // Image To Text // (POST /image-to-text) GenImageToText(w http.ResponseWriter, r *http.Request) @@ -2560,6 +2747,12 @@ func (_ Unimplemented) GenImageToImage(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotImplemented) } +// Image To Image Generic +// (POST /image-to-image-generic) +func (_ Unimplemented) GenImageToImageGeneric(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) +} + // Image To Text // (POST /image-to-text) func (_ Unimplemented) GenImageToText(w http.ResponseWriter, r *http.Request) { @@ -2696,6 +2889,23 @@ func (siw *ServerInterfaceWrapper) GenImageToImage(w http.ResponseWriter, r *htt handler.ServeHTTP(w, r.WithContext(ctx)) } +// GenImageToImageGeneric operation middleware +func (siw *ServerInterfaceWrapper) GenImageToImageGeneric(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + ctx = context.WithValue(ctx, HTTPBearerScopes, []string{}) + + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + siw.Handler.GenImageToImageGeneric(w, r) + })) + + for _, middleware := range siw.HandlerMiddlewares { + handler = middleware(handler) + } + + handler.ServeHTTP(w, r.WithContext(ctx)) +} + // GenImageToText operation middleware func (siw *ServerInterfaceWrapper) GenImageToText(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -2960,6 +3170,9 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/image-to-image", wrapper.GenImageToImage) }) + r.Group(func(r chi.Router) { + r.Post(options.BaseURL+"/image-to-image-generic", wrapper.GenImageToImageGeneric) + }) r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/image-to-text", wrapper.GenImageToText) }) @@ -2991,87 +3204,92 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xdeW/jtrb/KoTeA5IB7GzttA8B7h+ZpTPBTaZBljst2sCXlo5lTiRSJakk7rx89wcu", - "kkiJsuU0Sft6/dc4Epez/s4heaj5GsUsLxgFKkV0+DUS8RxyrH8enR2/55xx9TsBEXNSSMJodKjeIFCv", - "EAdRMCoA5SyBbCcaRQVnBXBJQI+Ri7Tb/XIOtnsOQuAUVD9JZAbRYXQqUvXXolB/CMkJTaOHh1HE4beS", - "cEiiw1/0qNdNl5rQuh+bfoFYRg+j6KhMCDu3VHZJOffoRzPGEVY9UAoUOFatukzpFvpHlv04iw5/+Rr9", - "N4dZdBj9124jzV0ryt1TSAi+Oj+JHq5HAUnYmSAxM+90uDXTufx6PAWYfsOSxSQFqhtesku4l4rcHi58", - "kq6KjOGkogbNSAZIMjQFJDmmquUUEiWTGeM5ltFhNCUU80XUoq+rxFGUg8QJltjMOsNlpvp/fYjacjlK", - "EqJ+4gx9YVNEqJmMMGppKbAQkKg/5BxQQQrICPXtqJorRIdS9oQkPh0dKj6WaUpoin7AcWUgx+9QqSZW", - "hlLJo6ispJ7aNE1CU3OQJacTSXIQEueF8GmQvIQOHee6D2r6mOnnnkqQhHu5gy7KomBcWdMtzkoQh2hL", - "AJVAY9gaoa07xpOtEVJmjgxRaMpYBpii7S01+ZZ6tzXDmYCtVzvonaEMEYHs6+1mvFc7VUuUA6YCUeYQ", - "uWNns+/U7/EUa601bRypWS4vG8msgoGOY4Tsfol7HOc4hUum/+n6R1qSBNMYJiLGGXhq+n7ndVtH72nM", - "So5TENZSZI0hgEiuX8QZE5AtUEboTWO8Sm+o4CwvJNqek3QO3OoO5XiBOCRlbIdAv5U4I3LxypXbB0sn", - "utB01vzSMp8CV/ySisEeTzdjS6YoJ7MFuiNy3vGrfnc38gvYuh53skSO+105voOUgybmbk5iQ0aDkIZS", - "IlBRirkW4R3midCtCCWS4My02WnTh1aLKWMcixWQcIRO2PkR2j5hd+NzTG/QUYILqZHplVU8pgkiUqCY", - "cRMdE+Vld0DSudSOa5hwAgx6f4/zIoND9BX9GmVYApXjmFFBhHK0xW4W52NF3Vgk99mv0SHa39kboV8j", - "Cpx8EbsFuYdsjLkcV28PHlwBnGjGng0HO/wMhEIKKZbkFibG+FcQcdm4ybZ4pd2rJAmguzmW6i+4j7My", - "ATTjLA+I+DiljCsLmiHfINGv5d7eNzHad8n+ZElDZ4a0EPVlPjF+PSmAh3jYb7PwSZsaYrMKEFyMKIBb", - "9jxCyhwdm8ZnwDvkECohNdar6aEz4KBZk9AKLft7e/30JEAZEUrHuuMOOmUczG9UihJnCrUAa8yyEGWh", - "qGJlWkokMnYHHNVUqGGSMtOeO12oeAM0lfMOf1V7dKGpDnHnineIVSyzyX6dCjwDuZjEc4hvPOGp0NeW", - "3hlwhYkqkOpuSHfTpigkyTXuz9rYpWChzBKVwrDZDKhQRsY4mmOez8rMJfPCjPpWE1MTa6O1phYg6Urk", - "AqxbckwTliODbz2iUI2D8q505Ulhb+d/euCazUwq0qRpuCgy0gQ5DpWOjWa299SbfS+QXVRzdrC5FfeL", - "SoEmsAUSAC+yr84Awgny4LBZs/5kkfMJE9RaJUNh+Q+hcf+UfV7X0u0qlQ7M6f5FEmBdlc5aoPhdaEE2", - "4zgHoQFZQMxoos3by0Nu1fAudz/04NZch31vztffB2c1LRGhSIdzMWDSj2bw0LyDbbeOP9iMr+Pnn2q1", - "hoz104mcqdaTaRnfgGxTsX/wfZuMq2pCpWK92lREKZHjnJVUKgWYMevllptQaJ2ZUKheWZhVP3MVO23P", - "O5JlCuwJ1a86Kjw1zd5ooj3G3NDOiIAJLtNJDyzvHXTy1JoF3RnhJGnA2GPYpMvoo7fwsIsODgLyaabT", - "5t6+JuGlMQcsKr69EK8JOCpT1A/wq9OXg9f/j7OXTV5RSeKOJC3r3d87+DaEh7rlWnD4WY/dnXXNCGNC", - "x5IQcwFpDlQe0YWcE5oedMPMlN0HNk1Rpg0IfYsw53iBUnILFGGBMJqy+2oLwPqZxsWR4v+nn3/6GRk0", - "drl9w+5719zdyY8rvBeG+MciPBY3E0KLUgb5Y3djDoJlpQY11Rjpxi2m5KIgsfZKvVjDqOBwS1gp1I+E", - "xLo3kdauRk1Wpf1i//7j/We0/fEfn/9x8Po7bZIXR6deJnmqZj7WZP7lVr15mSkvFjcTVspakEvw4Fjl", - "1iWMGgmaqMLtruBcJeBqQLMtiPMpSUslTCN6Y1ZihNhMAlV/JmWs9/1ASuC2p5xjqhCH0DQDRw0eVxXl", - "6EdDeQg8qDKqjPwOk5gxnoj12CsYoRLpnoRiCaIOoPW4zZIC0xTQL3uj/WtrIrq3nRfBfQGxNM2nYBpw", - "EOqhemTUl5BcYSWjwo9Ydi701vAQYtSdrOsMn+4PrJezmeXKKqLlC3dz4IAAx5Z8RJTi0PZPo59fNejn", - "JdK6WZsyJ3/XhGV4ClmAsBP9vM5oPNIqavYRoQmJtfyxagopZyVNbGsV7/e8JlMc37hNuuSaaZdsiGcs", - "JXINazHdBCrpWHmAmLNMZTjaPM1YiFAhVdRnM0Wixjj9PrDpfGJm7+p5aOzoxIQl8eOqqHdCH7ngfOJ9", - "2qcBxNKwlTx+P3BFCvj96/+gDaxB0tzsZK3KONfeOaqcM+C/b+clvQnlPbF6oRNUpUztlbg55OqeH0u7", - "3dRNevUANtPVo7os+lsfja7rmXrGrF53BiYSckXQgzNHPVY9kQ5jHUlKt6EizJGlEVRAgh/Ort6yvCgl", - "HNNZ4Oz5tD6ET0Biosz/w9kVik0f9xi4K1QDXzXWhXMv/MWUMzSJ4he3bMDx1xxyxheTGQfwOujH6Af1", - "eEk3ySTOAv0u9fNgR0JbpOkHwe0AnHs0fVJ/r9xXUwKhpqVHpM9qJaOKIEerLeWF1XslSUZ+1ypapWKl", - "2bJpjoTEkghJYvFI5b6wxoapYRQ5PE6sJbvdHIkhK9/gdO4whua+UQzlAxbCQ00ixEGQIN9a2rYQsJiP", - "l5dnPTVG6tXAIiMDFsMLcup6oW5BzrsKd8zMHuK05WendZhu2Onh9V84I4kerua6j5UKnJdy0h7PQXLD", - "SQjGXWrbA4Toxjy5w1x7vZXFoDoq5d9LETstSpXvmUqquuDnzGuzjPkWIDmcfShK1GdybrI76DjE5tZu", - "+7Pq2SrULZqG9byjhnHXdAJSXqKMC4mlGKQGDjgbq2itFbIMZBVZohr3kQpp+3xLKYbsv65WDPsBtfQS", - "/hFwJudvqzzbl6garhThtGyuOyLTpErNHMqAlrki9sd/RqPo/fn5j+fRKDp+d/LeJe/CTLCKYUuHy5dD", - "doArvVRcq2oytEwJrG57hNFeW7jZ6eraStfKTLHEquzV0nLd6resptI5f1xLMDqlXyaX/tVAIxVd27dy", - "KdBO0VtpeYiDAKMnJ6enpiy3a9AxoxKodL3urX0U2tphmeeg5yxb7ZzcNKpmcuh3CAuTfQ6/lSACJ/k5", - "vp9IdgO0fab0nbtnfI8uTZtwmqhnFoNDs0Ptg1uVaodpG6iFos4WjAeGwZW+5IBzr5+uyfTrKnAeXHhL", - "yAtlYiWH1pHi966xNY0CJ3eSFRN/j2G873RmBfpnUKKqX9EumHK7na0sBKmV4ltJZQd9VtJ4b8u6OShv", - "86zbPgqesQ9ZlFRKXalH7pDV7E623dRZ9WtbnZTCJ9iYMLoSQ7ZDuDO+M9zILA1yS2ollraQl0AIuQV9", - "jGZP084wx8ZjunDCWTYpebZi7/Hq/ESjqSinut6Z0BTdEowuOYlv9IkFkyxmmd2JTPTGtD10z8itPXkf", - "SzZulxegQhPnoutbQxa64kFNwa1y9DWILsppRsRc0Wz69pNeJQBVXoBpgjKWeuS9N2P0UDdwM1etVttZ", - "B5IM8ZJ25aZemB9f2HQHfWKSxICkrrycE4GIQGodmaBq8qquoqpgNZtiTM6BI85KCWKkN1yJRAkDgSiT", - "pkRMzYRRcKfZlFzAPY6lebYtXqEECqCJQIz6nJC8yCAHKm01Gk1QrutrprqEYUbSkuNpBloTque/jRn8", - "G2GeltUJ6aBMtLbtWtpfHzrHFrY2WTcGCdw5gQncXbD+EvAsa0qN8fmLVyEJNRwr27PqZaVMmdkaV4FA", - "idgO481pHvWZVeV4EJ76gpU8BndWQmOW+7PWYyDpHdxf1M+Dk7eTWY8SXyQuRIUxaABarZXeLQeXbra3", - "PuRpNKtKk5ZP94wY5gKvZBWOLYOswYj1tDaNJPvzrPqljXpZED7F4kasZcumb1Ui0mPA7kFte8HC8d0I", - "ldQ5q28qCQTaNl1f1dCnSw/8iwr+MaxfeLJy7dgZT4sgqPeY8b61qJbHljBhItHHeqa5pluf1PtTelBm", - "Bl55gdESJqrmVqrXLdqX6lcvfQPHUbl6USlT4Q0mplrPuUWHp6yUrYIq3a+rcCpmd91pPs9BVrWPZsI7", - "LNAsw2kKCcICfbr44bN3UKqGGX74pzSh3pjzZbdQtZ5xUMFZ0K/V4MqpTblDw0KMqUoQcByDEOaWY73B", - "PsCJjesKQ4oWm6tPra4+PV6dn4RUqdGXs9xehuql0tfYS/Pc5lIxE2D06bdO9EmmGLJ5Yg49h+8rmSPM", - "h9aZZmjZ/rzbN6OKx2u/9zJgUO/t3Ye+hdff527jUxbYd24OLimw31wW3FwW/PteFnz9H31XEF2AWqhL", - "QLpYtzCbNLp4U+9jbP3vljINUV+1ny6aks5NfdafdiOgg98DbwRYg2mFWD+E9sbZiwIgnvcFWo8LF7KO", - "UK7wRBSAb4CjBNTKngul40yBf7ZAcF9wEFpvKkxgqlWdqD4Qz6s6M2V02lbV40S3LIiMted0ltLVX0p2", - "1dRqCSsBbLql/jLjh/XoDPKMFxeHULIsWjRJ2fIQYUrk9f7Gsql68zXfXjxTCBjMyuqPjMXe+RKmC1vO", - "0ubwa8emrx/cGB63KgiaUybzQZ7WeVRQhvpB01TTjC7V01Wpq+LDTGVbOq41oOJk/V241ftu5lLnqkS9", - "ugKp2nprhTUPo9trhOqWqCFixeG0JdWV2fK9Ho3QccmJXFwoUgyfHy8vz94A5sDr7zppWDeP6kHmUhbR", - "gxqDBAv2juxd7rj+/A4vKTo6rvf93I2+E3ILhcKSo2N0XlKqJ1K4Zsba29nb2VMCYQVQXJDoMPpmZ39n", - "T2kLy7kme1d/1WUs2bhy4oKJUDSvP33jfKnI3K6wqy1WWGs4TtRSov1ZGG5OCd+wZNE62jZRH3O5q8Lu", - "uPpikVHzKiMIfYPmwVexivHOiZ9m+2Bvr0WFI/XdL8LEj2EkeAtEPXcrcJd6sT8rM9Q0G0XfPiEJTSVc", - "YP43OEHVGa2ed/9l5r2iuJRzxsnvkOiJ9795mYkts+g9lSoNvmQMnWBuKgO+3X/9Utw3CatGKoPlioSD", - "gycloVOV2CWmaYLqysXXL2V/x1QCpzhDF8BvgVcUODCqY64LoL9cP1yPIlHmOeaL6hNn6JKhKjXAqVDY", - "XYUShd73Y5NiYbEYU5zDmN0C5yTRyO+hwyjandtCs90KhVPQIvBBzK0SjJ4RQULViEOB5MGVUzWQKcf0", - "Oa1rDZeyWlXePTuvZqI/xmU1hmJTV9j1s2dePydfTonf47gyJGpu9NJKBeX6jlo4Kh8VRbaoLqp53wIR", - "5mi/4EwlWc5irROmWx9veeY47c32woHaLzrcROr+SL2JUOtGKHPj/5Kh+trnmiGK+I7hgsCAzFxvWBkc", - "WJ2Y+9/2eRmH/zMS81AF7sbr/+L5+QZ6Hg09j0yOieehLvDc1p/1CiLPh9DHrNZKOqqPv7wMBpnZXhiE", - "/M2kDfxsko5n8Pz6I0qPc/3KMUbRbkZuYexXPK5afgQXHk41s6ndcz9OKUtOIUFAE/39EhGEiHbx3VKY", - "eLyOegpXXxgleisNN4CxAYynAwxlZgYs/ghqZG3PNMiR5QNSBX3WWOp6BowyTNNSQVh9lN9FgZPT53L8", - "5ubSSzu7c51n498b/35C/9besrY/Z7lxYVuKPsb2u1rjg36Ptp/gsoXP+vYXpksy/sAnu5456+/M+MJu", - "7peUbxx94+hP5+iV91XGjQ4e4fei6yCjaFdF6AFHDx9aFcl67e8UIIeTeqfS65nCereWbHPKsHH7v4nb", - "6yq6P3DIIB3385zd1OMN2urzu7j/b5n576aqu8HVJqBsKv8wTZwSTO8/8+pBClPj96xQ4ZURvjBW+P+1", - "3AYrNljx9FhRu9DjwMJ212hROp/SDcKE/ZxnvRJA00X1fxXoK5FSoOaL5UG3bz4I+syrg2qiTXaw8fi/", - "icc7H9Nd09VL1xmEJkDo6VpfM6/qjd9mrEzQW5bnJSVygT5gCXd4EdkLwLrKWRzu7iYccD5OzdudzHbf", - "iVV3XVbfM/6F1FlF37D1QEK328UF2Z2CxLs1vw/XD/8XAAD//+pb7T5rdwAA", + "H4sIAAAAAAAC/+xde3Pbtpb/KhjuziSZkeRHm3bXM/cPJ00Tz7VTjx+37aQeX4g8ohCTAAuAttWsv/sO", + "XiRAgpLs2m5vq7/iSHic5+8cAAfQlyRlZcUoUCmSvS+JSOdQYv3n/vHBO84ZV39nIFJOKkkYTfbUNwjU", + "V4iDqBgVgEqWQTFJRknFWQVcEtBjlCLvdz+bg+1eghA4B9VPEllAspcciVz9b1Gp/wjJCc2Tu7tRwuHX", + "mnDIkr1PetSLtktDaNOPTT9DKpO7UbJfZ4SdWCr7pJwE9KMZ4wirHigHChyrVn2mdAv9R1H8MEv2Pn1J", + "/pvDLNlL/murleaWFeXWEWQEn58cJncXo4gk7EyQmZknPW7NdD6/AU8Rpt+wbHGZA9UNz9gZ3EpF7gAX", + "IUnnVcFw5qhBM1IAkgxNAUmOqWo5hUzJZMZ4iWWyl0wJxXyRdOjrK3GUlCBxhiU2s85wXaj+X+6Srlz2", + "s4yoP3GBPrMpItRMRhi1tFRYCMjUf+QcUEUqKAgN7cjNFaNDKfuSZCEdPSo+1HlOaI6+x6kzkIPvUK0m", + "Vobi5FE5K2mmNk2z2NQcZM3ppSQlCInLSoQ0SF5Dj44T3Qe1fcz080AlSMKtnKDTuqoYV9Z0jYsaxB56", + "IYBKoCm8GKEXN4xnL0ZImTkyRKEpYwVgil6+UJO/UN+9mOFCwItXE/SdoQwRgezXL9vxXk1cS1QCpgJR", + "5hE5sbPZ79Tf4ynWWmvbeFKzXJ61klkFAz3HiNn9Evc4KHEOZ0z/0/ePvCYZpilcihQXEKjp28nrro7e", + "0ZTVHOcgrKXIBkMAkVJ/kRZMQLFABaFXrfEqvaGKs7KS6OWc5HPgVneoxAvEIatTOwT6tcYFkYtXvtze", + "WzrRqaaz4ZfW5RS44pc4Bgc83YwtmaKczBbohsh5z6+G3d3IL2LretzLJXLc6cvxO8g5aGJu5iQ1ZLQI", + "aSglAlW1mGsR3mCeCd2KUCIJLkybSZc+tFpMBeNYrICEfXTITvbRy0N2Mz7B9ArtZ7iSGpleWcVjmiEi", + "BUoZN9ExU152AySfS+24hgkvwKB3t7isCthDX9AvSYElUDlOGRVEKEdbbBVpOVbUjUV2W/yS7KGdyfYI", + "/ZJQ4OSz2KrILRRjzOXYfbt75wvgUDP2ZDjY42dNKKSQY0mu4dIY/woizlo3eSleafeqSQboZo6l+h/c", + "pkWdAZpxVkZEfJBTxpUFzVBokOiXenv7qxTt+GR/tKShY0NajPq6vDR+fVkBj/Gw02XhozY1xGYOEHyM", + "qIBb9gJC6hIdmMbHwHvkECohN9ar6aEz4KBZk9AJLTvb28P0ZEAZEUrHuuMEHTEO5m9UixoXCrUAa8yy", + "EGWhyLEyrSUSBbsBjhoq1DBZXWjPnS5UvAGay3mPP9cenWqqY9z54l3HKpbZ5LBOBZ6BXFymc0ivAuGp", + "0NeV3jFwhYkqkOpuSHfTpigkKTXuz7rYpWChLjKVwrDZDKhQRsY4mmNezurCJ/PUjPpWE9MQa6O1phYg", + "60vkFKxbckwzViKDbwOiUI2j8na6CqSwPfmfAbhmM5OKtGkarqqCtEGOg9Ox0czLbfXNThDITt2cPWzu", + "xP3KKdAEtkgCEET2NTOA90pCJO0nAimjkrOiDWRAs45c/jeW2FfAU6BSsavkwyQurEth6UW3t2b4jyCR", + "kEx9W1XFgtDcl41t1EaxdzSLxTBLKwV5marIo8ghNI/E3u1Y7JXAS0JBoDm7QWWdzl3ckgxhIUhOnUL9", + "0RGhVS1FhF4KUvHXthwMvkuyhOTTN5PXI7SzPdm+SP78eVcn//kD8q6/YSKDxdXlgLiPsLhqRZ05G7ce", + "yCEnjArloZi2zWakKBChhk1aYUKlYl6qsRo1afj/cPtjkPCoFoOa2SRcf5+EK/n01fYI7b7uY9afPe9q", + "Zb5JuzZp11Da5aHZigzMZVarE7H4TuXacbQRxqOF0kfcKWyUtC5c/y6UHp5yyA872l6VW6+5ufYvkgHr", + "q3TWActvYgn0jOMShAZqASrh1AYfbAhdq+F97r4fWEDOddoSzPn62+ispqUK/jodEWtM+sEMHpt3bdtt", + "4hI24+u4+odarSHj/mlGyVTry2mdXoHsUrGz+22XjHM3YbCmUCLHJaupVAowYzb73n6ioXVmYqP6ygKv", + "+rNUwdT2vFEJ3RSUWtVXPRUemWZvNNEBY37IZ0TAJa7zywGg3t7t5dkNC7ozwlnWwnO4iNL7luhDsBKx", + "qxAOAsppoZcbg31Nwk5TDlg4voOYrwnYr3M0DPmr05rd1//B20ibTMNJ4oZkHevd2d79OoaHuuW94PBH", + "PXZ/1ntGGBM6loSYU8hLoHKfLuSc0Hy3H2am7DZyeo0KbUDoa4Q5xwuUk2ugCAuE0ZTduj0B62caF0eK", + "/59+/ulnZNDY5/YNu73H4v/A4b0wxD8U4fWCl1a1jPLHbsYcBCtqDWqlXgGrxh2m5KIiqfZKvYjDqOJw", + "TVgt1B8ZSXVvIq1djTor353bD7c/opcf/vHjP3Zff6NN8nT/6FV/NazJ/NOthsu6UF4sri5ZLRtBLsGD", + "A5Vt1zBqJWiiCrfHs3OVkqsBzfksLqckr5UwjeiNWYkRYjMJVP03q1N9AAtSArc95RxThTiE5gV4agi4", + "cpSjHwzlMfCgyqgK8htcpozxTNyPvYoRKpHuSSiWIJoA2ozbLjIwzQF92h7tXFgT0b3tvAhuK0ilaT4F", + "04CDUB+qj4z6MlIqrGRUhBHLzoXeGh5ijPqT9Z3h4+2u9XI2s1xZRXR84WYOHBDg1JKPiFIcevnT6OdX", + "LfoFibRu1qXMy981YQWeQhEh7FB/3mQ0AWmOmh1EaEZSLX+smkLOWU0z21rF++2gyRSnV36TPrlm2iWV", + "CQXLibyHtZhuAtV0rDxAzFmhMhxtnmYsRKiQKuqzmSJRY5z+PnL6f2hm7+t53djRiwlL4sd51Ww2P3DB", + "+cgbt48DiLVhK3v4PuGKFPDb13+jk8S1pLnZ21qVcd77CM85Z8R/385rehXLe1L1hU5QlTK1V+K22qhf", + "yCftdlM/6dUD2ExXj+qzGG59tLpuZhoY033dG5hIKBVBd94czVjNRDqM9SQp/YaKME+WRlARCb4/Pn/L", + "yqqWcEBnkSLAo6YaMgOJiTL/98fnKDV9/Hq8vlANfDVYF8+98GdTV9omip/9+k3PX0soGV9czjhA0EF/", + "jL5XHy/ppk9dI/3O9OfRjoR2SNMfRLcDcBnQ9FH9f+W+mhIINS0DIkNWnYwcQZ5WO8qLq/dckoL8plW0", + "SsVKs3XbHAmJJRGSpOKByn1mja2nhlHi8XhpLdnv5kkMWflGp/OHMTQPjWIoX2MhvK5JxDiIEhRaS9cW", + "Ihbz4ezseKDYW321ZrW3AYv1K6Obwu1+ZfR3DnfMzAHidOVnp/WYbtkZ4PVfuCCZHq7heogVB85LOemO", + "5yG54SQG4z613QFidGOe3WCuvd7KYq2CduXfSxE7r2qV75mS9qby+jhos4z5DiB5nL2vajRkcn6yu9Zx", + "iM2t/fbH7rNVqFu1DZt5Ry3jvulEpLxEGacSS7GWGjjgYqyitVbIMpBVZAk37gMV0vX5jlIM2X9erRj2", + "I2oZJPwD4ELO37o8O5SoGq4W8bRsrjsi08SlZh5lQOtSEfvDP5NR8u7k5IeTZJQcfHf4zifv1EywimFL", + "h8+XR3aEK71UvNf1ldgyJbK6HRBGd23hZ6erL7n4VmaKKFZlr5aWi06/ZZdbvPPHewlGp/TL5DK8Gmil", + "oi9ZrFwKdFP0Tloe4yDC6OHh0ZG5HxWvSQQqfa97az+Kbe2wInDQE1asdk5uGrmZPPo9wuJkn8CvNYjI", + "SX6Jby8luwLaPVP6xt8zvkVnpk08TdQzi7VDs0ftnX89yA7TNVALRb0tmAAMoyt9yQGXQT99OSastMBl", + "dOEtoayUidW8W6X5rW9sbaPIyZ1k1WW4xzDe8TqzCv0zKlHVr+oWUvndjleWhjRKCa3E2cGQlbTe27Fu", + "DsrbAuu2H0XP2NdZlDilrtQj98hqdye7buqt+rWtXtYiJNiYMDoX62yHcG98b7iRWRqUllQnlq6Ql0AI", + "uQZ9jGZP044xx8Zj4iXONS9W7D2enxxqNBX1VF88IzRH1wSjM07SK31iwSRLWWF3IjO9MW0P3QtybU/e", + "x5KNu+UFqNLExaqez3lUU3CtHP0eRFf1tCBirmg2fYdJdwmAywswzVDB8oC8d2aMAerW3MxVq9Vu1oEk", + "Q7ymfbmpL8wfn9l0gj4ySVJAUldkzolARCC1jsyQm9zVVbgKXLMpxuQcOOKsliBGesOVSJQxEIgyaYrG", + "dMU3iu40m5ILuMWpNJ+9FK9QBhXQTCBGQ05IWRVQApW2Po1mqNT1NVNdwjAjec3xtACtCdXz38YM/o0w", + "z2t3QrpWJtrYdiPtL3e9Ywt7SUw3BgncO4GJXCK1/hLxLGtKrfGFi1chCTUcK9uz6mW1zJnZGleBQInY", + "DhPMaT4aMivneBCf+pTVPAV/VkJTVoazNmMgGRzcnzafRyfvJrMBJaFIfIiKY9AaaHWv9G45uPSzvftD", + "nkYzV5q0fLonxDAfeCVzOLYMstZGrMe1aSTZH2fVz23Uy4LwERZX4l62bPq6EpEBA/YParsLFo5vRqim", + "3ll9W0kg0EvT9VUDfbr0ILxoER7DhoUnK9eOvfG0CKJ6TxkfWotqebwQJkxk+ljPNNd065P6cMoAyszA", + "K1+SsIQJ19xK9aJD+1L96qVv5DiqVF84ZSq8wcRel/IKtqeslp2CKt2vr3AqZjf9aX6cg3S1j2bCGyzQ", + "rMB5DhnCAn08/T64ofJRDbP+4Z/ShPrGnC/7harNjGsVnEX9Wg2unNqUO7QspJiqBAGnKQhhnptoNtjX", + "cGLjusKQosXm61Ora0iP5yeHMVVq9OWstLeGBqkMNfbcPHe5VMxEGH38rRN9kinW2Twxh57r7yuZI8y7", + "zplmbNn+tNs3I8fjRdh7GTCo7+0ViKGF11/nkYnHLLDvPeGwpMB+82rD5hLhX/cS4eu/9aMN6BTUQl0C", + "0sW6ldmk0cWbeh/jxf+9UKYhmjePpou2pHNTn/WH3Qjo4feaNwKswXRCbBhCB+PsaQWQzocCbcCFD1n7", + "qFR4IirAV8BRBmplz4XScaHAv1gguK04CK03FSYw1arOVB9I567OTBmdtlX1caZbVkSm2nN6S2n3PyU7", + "N7VawkoAm26p/5nx43r0BnnCi4vrULIsWrRJ2fIQYUrk9f7GsqkG87XQXgJTiBjMyuqPgqXB+RKmC1vO", + "0uXwS8+mL+78GJ52KgjaUybzMmLnPCoqQ/1B21TTjM7Up6tSV8WHmcq29FxrjYqT++/Crd53M5c6VyXq", + "7gqkahusFe55GN1dI7hbooaIFYfTllRfZsv3ejRCpzUncnGqSDF8fjg7O34DmANvHtjUsG4+agaZS1kl", + "d2oMEi3Y27e3u9PmHUReU7R/0Oz7+Rt9h+QaKoUl+wfopKZUT6RwzYy1PdmebCuBsAoorkiyl3w12Zls", + "K21hOddkb+nn9caSjZ0TV0zEonnzBqH3ZKS5XWFXW6yy1nCQqaVE930+bk4J37Bs0TnaNlEfc7mlwu7Y", + "PR1p1LzKCGKPAd6FKlYx3jvx02zvbm93qPCkvvVZmPixHgnBAlHP3QnctV7sz+oCtc1GydePSEJbCReZ", + "/w3OkDuj1fPuPM+85xTXcs44+Q0yPfHOV88zsWUWvaNSpcFnjKFDzE1lwNc7r5+L+zZh1UhlsFyRsLv7", + "qCT0qhL7xLRNUFO5+Pq57O+ASuAUF+gU+DVwR4EHozrm+gD66eLuYpSIuiwxX7i3ZtEZQy41wLlQ2O1C", + "iULv27FJsbBYjCkuYcyugXOSaeQP0GGUbM1todmWQ+EctAhCEPOrBJMnRJBYNeK6QHLny8kNZMoxQ06b", + "WsOlrLrKuyfn1Uz0+7h0Yyg2dYXdMHvm66fkyyvxexhXhkTNjV5aqaDc3FGLR+X9qioW7qJa8BaIMEf7", + "FWcqyfIWa70w3XlF74njdDDbMwfqsOhwE6mHI/UmQt03Qpkb/2esfRXuniGKhI7RA4Fx7r1c+RhggHCa", + "6ovfefPYyJyR1JQpYXElRohMYDJCrJbuabyR90zeCIkrkOl8dz1oaZ+HekaEcZNugGYDNH9JoEHeq2u/", + "A3BaP/FxZ40dAb1RbuBk9YZA+KbY88DAH7EhEKv834DAn3xfYINED0aiBy7KSeChPvBcN88JRpHnfewR", + "vXstdtyjU8+DQWa2ZwahcBN7Az+bHOQJPL95vO1hru8cY5RsFeQaxmGl9aqVTnSN492iMDXD/jO5suYU", + "MgQ00+8miShEdIt+l8LEw3U0UDD/zCgxWOG8AYwNYDweYCgzM2Dxe1Cj6HqmQY6iXCNV0DUOta6jwqjA", + "NK8VhDUlRH0UODx6Ksdvb0w+t7N71wg3/r3x70f0b+0t9/bnojQubK/AjLF9z2+8O+zR9uk/e+FC3zp1", + "P+kRdeTIU4FPnPX3ZnxmNw+vsmwcfePoj+fozvuccaPdB/i96DvIKNlSEXqNI8/3nZsQeu3vXXyIJ/Ve", + "hekThfV+Devm0GHj9n8Rt9fVu7/jcFN67hc4u6kDXmurL+zi/3C1+b1h9yaB2wSUbcUxpplX+h38mvMA", + "Upja4ieFiqB8+ZmxIvxt8Q1WbLDi8bGicaGHgYXtrtGi9p7wjsKEfUa4/XG/6cL9Roq+ii0Fan8pIer2", + "7UPET7w6cBNtsoONx/9FPN57xPuerl77ziA0AUJP1/kVBXfP4W3B6gy9ZWVZUyIX6D2WcIMXiX14QN+u", + "EHtbWxkHXI5z8+2ksN0nqequr/MMjH8qdVYxNGwzkNDttnBFtqYg8VbD793F3f8HAAD//44auThshQAA", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/worker/worker.go b/worker/worker.go index 6c5a2f4e..c1b0c639 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -653,6 +653,64 @@ func (w *Worker) LiveVideoToVideo(ctx context.Context, req GenLiveVideoToVideoJS return resp.JSON200, nil } +func (w *Worker) ImageToImageGeneric(ctx context.Context, req GenImageToImageGenericMultipartRequestBody) (*ImageResponse, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + c, err := w.borrowContainer(ctx, "image-to-image-generic", *req.ModelId) + if err != nil { + return nil, err + } + + var buf bytes.Buffer + mw, err := NewImageToImageGenericMultipartWriter(&buf, req) + if err != nil { + return nil, err + } + + resp, err := c.Client.GenImageToImageGenericWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf) + if err != nil { + return nil, err + } + + if resp.JSON400 != nil { + val, err := json.Marshal(resp.JSON400) + if err != nil { + return nil, err + } + slog.Error("image-to-image-generic container returned 400", slog.String("err", string(val))) + return nil, errors.New("image-to-image-generic container returned 400: " + resp.JSON400.Detail.Msg) + } + + if resp.JSON401 != nil { + val, err := json.Marshal(resp.JSON401) + if err != nil { + return nil, err + } + slog.Error("image-to-image-generic container returned 401", slog.String("err", string(val))) + return nil, errors.New("image-to-image-generic container returned 401: " + resp.JSON401.Detail.Msg) + } + + if resp.JSON422 != nil { + val, err := json.Marshal(resp.JSON422) + if err != nil { + return nil, err + } + slog.Error("image-to-image-generic container returned 422", slog.String("err", string(val))) + return nil, errors.New("image-to-image-generic container returned 422: " + string(val)) + } + + if resp.JSON500 != nil { + val, err := json.Marshal(resp.JSON500) + if err != nil { + return nil, err + } + slog.Error("image-to-image-generic container returned 500", slog.String("err", string(val))) + return nil, errors.New("image-to-image-generic container returned 500: " + resp.JSON500.Detail.Msg) + } + + return resp.JSON200, nil +} + func (w *Worker) EnsureImageAvailable(ctx context.Context, pipeline string, modelID string) error { return w.manager.EnsureImageAvailable(ctx, pipeline, modelID) }