From a68647115842dfbb84366adf2c369dd11bb08c5e Mon Sep 17 00:00:00 2001 From: Tim Dockhorn Date: Thu, 14 Nov 2024 12:44:34 -0800 Subject: [PATCH 1/5] apply ruff --- .github/workflows/vibe.yaml | 16 +++++++++ demo_gr.py | 62 ++++++++++++++++++++++++--------- pyproject.toml | 1 + src/flux/__init__.py | 6 ++-- src/flux/api.py | 31 ++++------------- src/flux/cli.py | 8 ++--- src/flux/model.py | 11 ++++-- src/flux/modules/conditioner.py | 3 +- 8 files changed, 87 insertions(+), 51 deletions(-) create mode 100644 .github/workflows/vibe.yaml diff --git a/.github/workflows/vibe.yaml b/.github/workflows/vibe.yaml new file mode 100644 index 00000000..87ddd297 --- /dev/null +++ b/.github/workflows/vibe.yaml @@ -0,0 +1,16 @@ +name: CI +on: push +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff + - name: Run Ruff + run: ruff check --output-format=github . diff --git a/demo_gr.py b/demo_gr.py index 03d9909b..3b4d022b 100644 --- a/demo_gr.py +++ b/demo_gr.py @@ -2,11 +2,11 @@ import time import uuid -import torch import gradio as gr import numpy as np +import torch from einops import rearrange -from PIL import Image, ExifTags +from PIL import ExifTags, Image from transformers import pipeline from flux.cli import SamplingOptions @@ -15,6 +15,7 @@ NSFW_THRESHOLD = 0.85 + def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool): t5 = load_t5(device, max_length=256 if is_schnell else 512) clip = load_clip(device) @@ -23,6 +24,7 @@ def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool) nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) return model, ae, t5, clip, nsfw_classifier + class FluxGenerator: def __init__(self, model_name: str, device: str, offload: bool): self.device = torch.device(device) @@ -70,7 +72,7 @@ def generate_image( if init_image is not None: if isinstance(init_image, np.ndarray): init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 255.0 - init_image = init_image.unsqueeze(0) + init_image = init_image.unsqueeze(0) init_image = init_image.to(self.device) init_image = torch.nn.functional.interpolate(init_image, (opts.height, opts.width)) if self.offload: @@ -151,37 +153,49 @@ def generate_image( exif_data[ExifTags.Base.Model] = self.model_name if add_sampling_metadata: exif_data[ExifTags.Base.ImageDescription] = prompt - + img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0) return img, str(opts.seed), filename, None else: return None, str(opts.seed), None, "Your generated image may contain NSFW content." -def create_demo(model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False): + +def create_demo( + model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False +): generator = FluxGenerator(model_name, device, offload) is_schnell = model_name == "flux-schnell" with gr.Blocks() as demo: gr.Markdown(f"# Flux Image Generation Demo - Model: {model_name}") - + with gr.Row(): with gr.Column(): - prompt = gr.Textbox(label="Prompt", value="a photo of a forest with mist swirling around the tree trunks. The word \"FLUX\" is painted over it in big, red brush strokes with visible texture") + prompt = gr.Textbox( + label="Prompt", + value='a photo of a forest with mist swirling around the tree trunks. The word "FLUX" is painted over it in big, red brush strokes with visible texture', + ) do_img2img = gr.Checkbox(label="Image to Image", value=False, interactive=not is_schnell) init_image = gr.Image(label="Input Image", visible=False) - image2image_strength = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False) - + image2image_strength = gr.Slider( + 0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False + ) + with gr.Accordion("Advanced Options", open=False): width = gr.Slider(128, 8192, 1360, step=16, label="Width") height = gr.Slider(128, 8192, 768, step=16, label="Height") num_steps = gr.Slider(1, 50, 4 if is_schnell else 50, step=1, label="Number of steps") - guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell) + guidance = gr.Slider( + 1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell + ) seed = gr.Textbox(-1, label="Seed (-1 for random)") - add_sampling_metadata = gr.Checkbox(label="Add sampling parameters to metadata?", value=True) - + add_sampling_metadata = gr.Checkbox( + label="Add sampling parameters to metadata?", value=True + ) + generate_btn = gr.Button("Generate") - + with gr.Column(): output_image = gr.Image(label="Generated Image") seed_output = gr.Number(label="Used Seed") @@ -198,17 +212,33 @@ def update_img2img(do_img2img): generate_btn.click( fn=generator.generate_image, - inputs=[width, height, num_steps, guidance, seed, prompt, init_image, image2image_strength, add_sampling_metadata], + inputs=[ + width, + height, + num_steps, + guidance, + seed, + prompt, + init_image, + image2image_strength, + add_sampling_metadata, + ], outputs=[output_image, seed_output, download_btn, warning_text], ) return demo + if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser(description="Flux") - parser.add_argument("--name", type=str, default="flux-schnell", choices=list(configs.keys()), help="Model name") - parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use") + parser.add_argument( + "--name", type=str, default="flux-schnell", choices=list(configs.keys()), help="Model name" + ) + parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use" + ) parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use") parser.add_argument("--share", action="store_true", help="Create a public link to your demo") args = parser.parse_args() diff --git a/pyproject.toml b/pyproject.toml index 72f921b4..2c3692e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "protobuf", "requests", "invisible-watermark", + "ruff == 0.6.8", ] [project.optional-dependencies] diff --git a/src/flux/__init__.py b/src/flux/__init__.py index 43c365a4..dddc6a38 100644 --- a/src/flux/__init__.py +++ b/src/flux/__init__.py @@ -1,6 +1,8 @@ try: - from ._version import version as __version__ # type: ignore - from ._version import version_tuple + from ._version import ( + version as __version__, # type: ignore + version_tuple, + ) except ImportError: __version__ = "unknown (no version information available)" version_tuple = (0, 0, "unknown", "noinfo") diff --git a/src/flux/api.py b/src/flux/api.py index ae60d91a..6a608840 100644 --- a/src/flux/api.py +++ b/src/flux/api.py @@ -94,23 +94,14 @@ def __init__( elif interval is not None and not (1.0 <= interval <= 4.0): raise ValueError(f"interval must be between 1 and 4, got {interval}") elif safety_tolerance is not None and not (0 <= safety_tolerance <= 6.0): - raise ValueError( - f"safety_tolerance must be between 0 and 6, got {interval}" - ) + raise ValueError(f"safety_tolerance must be between 0 and 6, got {interval}") if name == "flux.1-dev": if interval is not None: raise ValueError("Interval is not supported for flux.1-dev") if name == "flux.1.1-pro": - if ( - interval is not None - or num_steps is not None - or guidance is not None - ): - raise ValueError( - "Interval, num_steps and guidance are not supported for " - "flux.1.1-pro" - ) + if interval is not None or num_steps is not None or guidance is not None: + raise ValueError("Interval, num_steps and guidance are not supported for " "flux.1.1-pro") self.name = name self.request_json = { @@ -124,9 +115,7 @@ def __init__( "interval": interval, "safety_tolerance": safety_tolerance, } - self.request_json = { - key: value for key, value in self.request_json.items() if value is not None - } + self.request_json = {key: value for key, value in self.request_json.items() if value is not None} self.request_id: str | None = None self.result: dict | None = None @@ -157,9 +146,7 @@ def request(self): ) result = response.json() if response.status_code != 200: - raise ApiException( - status_code=response.status_code, detail=result.get("detail") - ) + raise ApiException(status_code=response.status_code, detail=result.get("detail")) self.request_id = response.json()["id"] def retrieve(self) -> dict: @@ -181,17 +168,13 @@ def retrieve(self) -> dict: ) result = response.json() if "status" not in result: - raise ApiException( - status_code=response.status_code, detail=result.get("detail") - ) + raise ApiException(status_code=response.status_code, detail=result.get("detail")) elif result["status"] == "Ready": self.result = result["result"] elif result["status"] == "Pending": time.sleep(0.5) else: - raise ApiException( - status_code=200, detail=f"API returned status '{result['status']}'" - ) + raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'") return self.result @property diff --git a/src/flux/cli.py b/src/flux/cli.py index a2206f0c..a6f04d21 100644 --- a/src/flux/cli.py +++ b/src/flux/cli.py @@ -8,14 +8,14 @@ from einops import rearrange from fire import Fire from PIL import ExifTags, Image +from transformers import pipeline from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack -from flux.util import (configs, embed_watermark, load_ae, load_clip, - load_flow_model, load_t5) -from transformers import pipeline +from flux.util import configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5 NSFW_THRESHOLD = 0.85 + @dataclass class SamplingOptions: prompt: str @@ -229,7 +229,7 @@ def main( img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0] - + if nsfw_score < NSFW_THRESHOLD: exif_data = Image.Exif() exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" diff --git a/src/flux/model.py b/src/flux/model.py index f33ab832..75a681f3 100644 --- a/src/flux/model.py +++ b/src/flux/model.py @@ -3,9 +3,14 @@ import torch from torch import Tensor, nn -from flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, - MLPEmbedder, SingleStreamBlock, - timestep_embedding) +from flux.modules.layers import ( + DoubleStreamBlock, + EmbedND, + LastLayer, + MLPEmbedder, + SingleStreamBlock, + timestep_embedding, +) @dataclass diff --git a/src/flux/modules/conditioner.py b/src/flux/modules/conditioner.py index 7cdd8818..e60297e4 100644 --- a/src/flux/modules/conditioner.py +++ b/src/flux/modules/conditioner.py @@ -1,6 +1,5 @@ from torch import Tensor, nn -from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel, - T5Tokenizer) +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer class HFEmbedder(nn.Module): From cda68fb2cba493e739ac6b9b62941c309d72dadc Mon Sep 17 00:00:00 2001 From: Tim Dockhorn Date: Thu, 14 Nov 2024 12:44:58 -0800 Subject: [PATCH 2/5] rename --- .github/workflows/{vibe.yaml => ci.yaml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename .github/workflows/{vibe.yaml => ci.yaml} (100%) diff --git a/.github/workflows/vibe.yaml b/.github/workflows/ci.yaml similarity index 100% rename from .github/workflows/vibe.yaml rename to .github/workflows/ci.yaml From 050d00cfd3cc98b29916e175d22f7c1db21ee95a Mon Sep 17 00:00:00 2001 From: Tim Dockhorn Date: Thu, 14 Nov 2024 12:49:37 -0800 Subject: [PATCH 3/5] specify ruff version for CI --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 87ddd297..e9b94a10 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -11,6 +11,6 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install ruff + pip install ruff==0.6.8 - name: Run Ruff run: ruff check --output-format=github . From 3db5b276ab2cda1c3d9d30ab0b8b884cf75f2cd0 Mon Sep 17 00:00:00 2001 From: Tim Dockhorn Date: Thu, 14 Nov 2024 12:55:31 -0800 Subject: [PATCH 4/5] also check imports --- .github/workflows/ci.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e9b94a10..fde6ad9a 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -14,3 +14,5 @@ jobs: pip install ruff==0.6.8 - name: Run Ruff run: ruff check --output-format=github . + - name: Check imports + run: ruff check --select I --output-format=github . From 517f838703663b6a9920a81d32e004136bec3f8f Mon Sep 17 00:00:00 2001 From: Tim Dockhorn Date: Thu, 14 Nov 2024 12:59:30 -0800 Subject: [PATCH 5/5] check formatting --- .github/workflows/ci.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index fde6ad9a..8219c61b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -16,3 +16,5 @@ jobs: run: ruff check --output-format=github . - name: Check imports run: ruff check --select I --output-format=github . + - name: Check formatting + run: ruff format --check .