add height width to image inputs to control output image dims
IlyasMoutawwakil committed Jan 10, 2025
1 parent c9e0d0f commit 4051f76
Showing 1 changed file with 13 additions and 12 deletions.
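
In practical terms, the commit makes the tests request explicit output dimensions: once "height" and "width" are part of the input dict, the generated images should come back at exactly those sizes. A minimal sketch of that behaviour (not taken from this commit; the checkpoint id is a placeholder for any text-to-image model that optimum can export to ONNX):

from optimum.onnxruntime import ORTPipelineForText2Image

# Placeholder checkpoint id; substitute a real text-to-image model exportable to ONNX.
pipeline = ORTPipelineForText2Image.from_pretrained("model-id-or-path", export=True)

inputs = {
    "prompt": "a drawing of a cat",
    "num_inference_steps": 2,
    "guidance_scale": 7.5,
    "output_type": "np",
    "height": 128,  # requested output height in pixels
    "width": 128,   # requested output width in pixels
}

images = pipeline(**inputs).images
# With output_type="np", the result is a numpy array of shape (batch, height, width, 3),
# so the requested dimensions show up directly in the output shape.
assert images.shape == (1, 128, 128, 3)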
tests/onnxruntime/test_diffusion.py: 25 changes (13 additions, 12 deletions)
@@ -23,8 +23,8 @@
     DiffusionPipeline,
 )
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
-from diffusers.utils import load_image
 from parameterized import parameterized
+from PIL import Image
 from transformers.testing_utils import require_torch_gpu
 from utils_onnxruntime_tests import MODEL_NAMES, SEED, ORTModelTestMixin

@@ -34,7 +34,7 @@
     ORTPipelineForInpainting,
     ORTPipelineForText2Image,
 )
-from optimum.utils import is_transformers_version
+from optimum.utils import is_diffusers_version, is_transformers_version
 from optimum.utils.testing_utils import grid_parameters, require_diffusers


@@ -54,15 +54,13 @@ def _generate_prompts(batch_size=1):
         "guidance_scale": 7.5,
         "output_type": "np",
     }
+
     return inputs


 def _generate_images(height=128, width=128, batch_size=1, channel=3, input_type="pil"):
     if input_type == "pil":
-        image = load_image(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-            "/in_paint/overture-creations-5sI6fQgYIuo.png"
-        ).resize((width, height))
+        image = Image.new("RGB", (width, height))
     elif input_type == "np":
         image = np.random.rand(height, width, channel)
     elif input_type == "pt":
@@ -105,8 +103,7 @@ class ORTPipelineForText2ImageTest(ORTModelTestMixin):
     def generate_inputs(self, height=128, width=128, batch_size=1):
         inputs = _generate_prompts(batch_size=batch_size)

-        inputs["height"] = height
-        inputs["width"] = width
+        inputs["height"], inputs["width"] = height, width

         return inputs

@@ -229,7 +226,11 @@ def test_shape(self, model_arch: str):

         if model_arch == "flux":
             channels = pipeline.transformer.config.in_channels
-            expected_shape = (batch_size, expected_height * expected_width, channels)
+            if is_diffusers_version(">=", "0.32.0"):
+                expected_shape = (batch_size, expected_height * expected_width // 4, channels)
+            else:
+                expected_shape = (batch_size, expected_height * expected_width, channels)
+
         elif model_arch == "stable-diffusion-3":
             out_channels = pipeline.transformer.config.out_channels
             expected_shape = (batch_size, out_channels, expected_height, expected_width)
@@ -363,6 +364,7 @@ def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_
             height=height, width=width, batch_size=batch_size, channel=channel, input_type=input_type
         )

+        inputs["height"], inputs["width"] = height, width
         inputs["strength"] = 0.75

         return inputs
@@ -491,7 +493,7 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
         ort_images = ort_pipeline(**inputs, generator=get_generator("pt", SEED)).images
         diffusers_images = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images

-        np.testing.assert_allclose(ort_images, diffusers_images, atol=1e-4, rtol=1e-2)
+        np.testing.assert_allclose(ort_images, diffusers_images, atol=3e-4, rtol=1e-2)

     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     @require_diffusers
@@ -602,9 +604,8 @@ def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_
             height=height, width=width, batch_size=batch_size, channel=1, input_type=input_type
         )

+        inputs["height"], inputs["width"] = height, width
         inputs["strength"] = 0.75
-        inputs["height"] = height
-        inputs["width"] = width

         return inputs

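A side note on the versioned branch in test_shape: Flux packs 2x2 patches of the latent grid into single tokens, so on diffusers >= 0.32.0 (where the pipeline's scale-factor convention changed) the test expects a sequence length of expected_height * expected_width // 4 rather than the full product. A toy calculation, purely to make the arithmetic concrete (the channel count and latent grid below are assumed values, not read from the tiny test models):

# Illustrative only: how the two expected_shape branches differ for a Flux-style model.
batch_size = 1
channels = 64                              # e.g. pipeline.transformer.config.in_channels (assumed value)
expected_height, expected_width = 16, 16   # assumed latent grid for a 128x128 request

# diffusers >= 0.32.0: 2x2 latent patches are packed into single tokens
shape_new = (batch_size, expected_height * expected_width // 4, channels)  # (1, 64, 64)

# older diffusers: one token per latent position
shape_old = (batch_size, expected_height * expected_width, channels)       # (1, 256, 64)

print(shape_new, shape_old)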
