Skip to content

Commit

Permalink
Refactor flux.py and flux.xml for improved model selection (bgruening…
Browse files Browse the repository at this point in the history
  • Loading branch information
arash77 authored Oct 29, 2024
1 parent 6954dde commit 0959e15
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
17 changes: 13 additions & 4 deletions tools/flux/flux.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import sys

import torch
from diffusers import FluxPipeline

model = sys.argv[1]
model_path = sys.argv[1]

prompt_type = sys.argv[2]
if prompt_type == "file":
Expand All @@ -12,20 +13,28 @@
elif prompt_type == "text":
prompt = sys.argv[3]

if model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
if "dev" in model_path:
num_inference_steps = 20
elif "schnell" in model_path:
num_inference_steps = 4
else:
print("Invalid model!")
sys.exit(1)

snapshots = []
for d in os.listdir(os.path.join(model_path, "snapshots")):
snapshots.append(os.path.join(model_path, "snapshots", d))
latest_snapshot_path = max(snapshots, key=os.path.getmtime)

pipe = FluxPipeline.from_pretrained(model, torch_dtype=torch.bfloat16)
pipe = FluxPipeline.from_pretrained(latest_snapshot_path, torch_dtype=torch.bfloat16)
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
pipe.to(torch.float16)

image = pipe(
prompt,
num_inference_steps=4,
num_inference_steps=num_inference_steps,
generator=torch.Generator("cpu").manual_seed(42),
).images[0]

Expand Down
5 changes: 2 additions & 3 deletions tools/flux/flux.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<description>text-to-image model</description>
<macros>
<token name="@TOOL_VERSION@">2024</token>
<token name="@VERSION_SUFFIX@">0</token>
<token name="@VERSION_SUFFIX@">1</token>
</macros>
<requirements>
<requirement type="package" version="3.12">python</requirement>
Expand All @@ -16,9 +16,8 @@
<requirement type="package" version="0.24.6">huggingface_hub</requirement>
</requirements>
<command detect_errors="exit_code"><![CDATA[
export HF_HOME='$flux_models.fields.path' &&
python '$__tool_directory__/flux.py'
'$flux_models'
'$flux_models.fields.path'
'$input_type_selector'
'$prompt'
]]></command>
Expand Down

0 comments on commit 0959e15

Please sign in to comment.