Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ruff ci #194

Merged
merged 5 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
name: CI
on: push
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff==0.6.8
- name: Run Ruff
run: ruff check --output-format=github .
- name: Check imports
run: ruff check --select I --output-format=github .
- name: Check formatting
run: ruff format --check .
62 changes: 46 additions & 16 deletions demo_gr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import time
import uuid

import torch
import gradio as gr
import numpy as np
import torch
from einops import rearrange
from PIL import Image, ExifTags
from PIL import ExifTags, Image
from transformers import pipeline

from flux.cli import SamplingOptions
Expand All @@ -15,6 +15,7 @@

NSFW_THRESHOLD = 0.85


def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
t5 = load_t5(device, max_length=256 if is_schnell else 512)
clip = load_clip(device)
Expand All @@ -23,6 +24,7 @@ def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool)
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
return model, ae, t5, clip, nsfw_classifier


class FluxGenerator:
def __init__(self, model_name: str, device: str, offload: bool):
self.device = torch.device(device)
Expand Down Expand Up @@ -70,7 +72,7 @@ def generate_image(
if init_image is not None:
if isinstance(init_image, np.ndarray):
init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 255.0
init_image = init_image.unsqueeze(0)
init_image = init_image.unsqueeze(0)
init_image = init_image.to(self.device)
init_image = torch.nn.functional.interpolate(init_image, (opts.height, opts.width))
if self.offload:
Expand Down Expand Up @@ -151,37 +153,49 @@ def generate_image(
exif_data[ExifTags.Base.Model] = self.model_name
if add_sampling_metadata:
exif_data[ExifTags.Base.ImageDescription] = prompt

img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)

return img, str(opts.seed), filename, None
else:
return None, str(opts.seed), None, "Your generated image may contain NSFW content."

def create_demo(model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False):

def create_demo(
model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False
):
generator = FluxGenerator(model_name, device, offload)
is_schnell = model_name == "flux-schnell"

with gr.Blocks() as demo:
gr.Markdown(f"# Flux Image Generation Demo - Model: {model_name}")

with gr.Row():
with gr.Column():
prompt = gr.Textbox(label="Prompt", value="a photo of a forest with mist swirling around the tree trunks. The word \"FLUX\" is painted over it in big, red brush strokes with visible texture")
prompt = gr.Textbox(
label="Prompt",
value='a photo of a forest with mist swirling around the tree trunks. The word "FLUX" is painted over it in big, red brush strokes with visible texture',
)
do_img2img = gr.Checkbox(label="Image to Image", value=False, interactive=not is_schnell)
init_image = gr.Image(label="Input Image", visible=False)
image2image_strength = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False)

image2image_strength = gr.Slider(
0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False
)

with gr.Accordion("Advanced Options", open=False):
width = gr.Slider(128, 8192, 1360, step=16, label="Width")
height = gr.Slider(128, 8192, 768, step=16, label="Height")
num_steps = gr.Slider(1, 50, 4 if is_schnell else 50, step=1, label="Number of steps")
guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell)
guidance = gr.Slider(
1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell
)
seed = gr.Textbox(-1, label="Seed (-1 for random)")
add_sampling_metadata = gr.Checkbox(label="Add sampling parameters to metadata?", value=True)

add_sampling_metadata = gr.Checkbox(
label="Add sampling parameters to metadata?", value=True
)

generate_btn = gr.Button("Generate")

with gr.Column():
output_image = gr.Image(label="Generated Image")
seed_output = gr.Number(label="Used Seed")
Expand All @@ -198,17 +212,33 @@ def update_img2img(do_img2img):

generate_btn.click(
fn=generator.generate_image,
inputs=[width, height, num_steps, guidance, seed, prompt, init_image, image2image_strength, add_sampling_metadata],
inputs=[
width,
height,
num_steps,
guidance,
seed,
prompt,
init_image,
image2image_strength,
add_sampling_metadata,
],
outputs=[output_image, seed_output, download_btn, warning_text],
)

return demo


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description="Flux")
parser.add_argument("--name", type=str, default="flux-schnell", choices=list(configs.keys()), help="Model name")
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
parser.add_argument(
"--name", type=str, default="flux-schnell", choices=list(configs.keys()), help="Model name"
)
parser.add_argument(
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use"
)
parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
args = parser.parse_args()
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"protobuf",
"requests",
"invisible-watermark",
"ruff == 0.6.8",
]

[project.optional-dependencies]
Expand Down
6 changes: 4 additions & 2 deletions src/flux/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
try:
from ._version import version as __version__ # type: ignore
from ._version import version_tuple
from ._version import (
version as __version__, # type: ignore
version_tuple,
)
except ImportError:
__version__ = "unknown (no version information available)"
version_tuple = (0, 0, "unknown", "noinfo")
Expand Down
31 changes: 7 additions & 24 deletions src/flux/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,14 @@ def __init__(
elif interval is not None and not (1.0 <= interval <= 4.0):
raise ValueError(f"interval must be between 1 and 4, got {interval}")
elif safety_tolerance is not None and not (0 <= safety_tolerance <= 6.0):
raise ValueError(
f"safety_tolerance must be between 0 and 6, got {interval}"
)
raise ValueError(f"safety_tolerance must be between 0 and 6, got {interval}")

if name == "flux.1-dev":
if interval is not None:
raise ValueError("Interval is not supported for flux.1-dev")
if name == "flux.1.1-pro":
if (
interval is not None
or num_steps is not None
or guidance is not None
):
raise ValueError(
"Interval, num_steps and guidance are not supported for "
"flux.1.1-pro"
)
if interval is not None or num_steps is not None or guidance is not None:
raise ValueError("Interval, num_steps and guidance are not supported for " "flux.1.1-pro")

self.name = name
self.request_json = {
Expand All @@ -124,9 +115,7 @@ def __init__(
"interval": interval,
"safety_tolerance": safety_tolerance,
}
self.request_json = {
key: value for key, value in self.request_json.items() if value is not None
}
self.request_json = {key: value for key, value in self.request_json.items() if value is not None}

self.request_id: str | None = None
self.result: dict | None = None
Expand Down Expand Up @@ -157,9 +146,7 @@ def request(self):
)
result = response.json()
if response.status_code != 200:
raise ApiException(
status_code=response.status_code, detail=result.get("detail")
)
raise ApiException(status_code=response.status_code, detail=result.get("detail"))
self.request_id = response.json()["id"]

def retrieve(self) -> dict:
Expand All @@ -181,17 +168,13 @@ def retrieve(self) -> dict:
)
result = response.json()
if "status" not in result:
raise ApiException(
status_code=response.status_code, detail=result.get("detail")
)
raise ApiException(status_code=response.status_code, detail=result.get("detail"))
elif result["status"] == "Ready":
self.result = result["result"]
elif result["status"] == "Pending":
time.sleep(0.5)
else:
raise ApiException(
status_code=200, detail=f"API returned status '{result['status']}'"
)
raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
return self.result

@property
Expand Down
8 changes: 4 additions & 4 deletions src/flux/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
from einops import rearrange
from fire import Fire
from PIL import ExifTags, Image
from transformers import pipeline

from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import (configs, embed_watermark, load_ae, load_clip,
load_flow_model, load_t5)
from transformers import pipeline
from flux.util import configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5

NSFW_THRESHOLD = 0.85


@dataclass
class SamplingOptions:
prompt: str
Expand Down Expand Up @@ -229,7 +229,7 @@ def main(

img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]

if nsfw_score < NSFW_THRESHOLD:
exif_data = Image.Exif()
exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
Expand Down
11 changes: 8 additions & 3 deletions src/flux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
import torch
from torch import Tensor, nn

from flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
from flux.modules.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)


@dataclass
Expand Down
3 changes: 1 addition & 2 deletions src/flux/modules/conditioner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from torch import Tensor, nn
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
T5Tokenizer)
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer


class HFEmbedder(nn.Module):
Expand Down
Loading