diff --git a/.eslintrc.json b/.eslintrc.json
index 62feb13a5..c86dbb749 100644
--- a/.eslintrc.json
+++ b/.eslintrc.json
@@ -37,14 +37,19 @@
     "object-curly-newline":"off",
     "prefer-rest-params":"off",
     "prefer-destructuring":"off",
-    "radix":"off"
+    "radix":"off",
+    "node/shebang": "off"
   },
   "globals": {
     // asssets
     "panzoom": "readonly",
-    // script.js
+    // logger.js
     "log": "readonly",
     "debug": "readonly",
+    "error": "readonly",
+    "xhrGet": "readonly",
+    "xhrPost": "readonly",
+    // script.js
     "gradioApp": "readonly",
     "executeCallbacks": "readonly",
     "onAfterUiUpdate": "readonly",
@@ -87,7 +92,6 @@
     // settings.js
     "registerDragDrop": "readonly",
     // extraNetworks.js
-    "requestGet": "readonly",
     "getENActiveTab": "readonly",
     "quickApplyStyle": "readonly",
     "quickSaveStyle": "readonly",
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
index cf176d6cf..a40320c63 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -106,10 +106,14 @@ body:
         - StableDiffusion 1.5
         - StableDiffusion 2.1
         - StableDiffusion XL
-        - StableDiffusion 3
-        - PixArt
+        - StableDiffusion 3.x
         - StableCascade
+        - FLUX.1
+        - PixArt
         - Kandinsky
+        - Playground
+        - AuraFlow
+        - Any Video Model
         - Other
       default: 0
     validations:
diff --git a/.pylintrc b/.pylintrc
index 59f1cb127..ad42ddd13 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -13,6 +13,7 @@ ignore-paths=/usr/lib/.*$,
   modules/control/units,
   modules/ctrlx,
   modules/dml,
+  modules/freescale,
   modules/ggml,
   modules/hidiffusion,
   modules/hijack,
diff --git a/.ruff.toml b/.ruff.toml
index c2d4a6f9a..4bab64260 100644
--- a/.ruff.toml
+++ b/.ruff.toml
@@ -7,6 +7,7 @@ exclude = [
  "modules/consistory",
  "modules/control/proc",
  "modules/control/units",
+ "modules/freescale",
  "modules/ggml",
  "modules/hidiffusion",
  "modules/hijack",
diff --git a/CHANGELOG.md b/CHANGELOG.md
index bc6cd163b..b778dca21 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,21 +1,230 @@
 # Change Log for SD.Next

-## Update for 2024-11-22
-
-- Model loader improvements:
+## Update for 2024-12-24
+
+### Highlights for 2024-12-24
+
+### SD.Next Xmass edition: *What's new?*
+
+While we have several new supported models, workflows and tools, this release is primarily about *quality-of-life improvements*:
+- New memory management engine
+  list of changes that went into this one is long: changes to GPU offloading, brand new LoRA loader, system memory management, on-the-fly quantization, improved gguf loader, etc.
+  but the main goal is enabling modern large models to run on standard consumer GPUs
+  without the performance hits typically associated with aggressive memory swapping and the need for constant manual tweaks
+- New [documentation website](https://vladmandic.github.io/sdnext-docs/)
+  with full search and tons of new documentation
+- New settings panel with simplified and streamlined configuration
+
+We've also added support for several new models such as the highly anticipated [NVLabs Sana](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px) (see [supported models](https://vladmandic.github.io/sdnext-docs/Model-Support/) for the full list)
+And several new SOTA video models: [Lightricks LTX-Video](https://huggingface.co/Lightricks/LTX-Video), [Hunyuan Video](https://huggingface.co/tencent/HunyuanVideo) and [Genmo Mochi.1 Preview](https://huggingface.co/genmo/mochi-1-preview)
+
+And a lot of **Control** and **IPAdapter** goodies
+- for **SDXL** there is a new [ProMax](https://huggingface.co/xinsir/controlnet-union-sdxl-1.0), plus improved *Union* and *Tiling* models
+- for **FLUX.1** there are [Flux Tools](https://blackforestlabs.ai/flux-1-tools/) as well as official *Canny* and *Depth* models,
+  a cool [Redux](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) model as well as the [XLabs](https://huggingface.co/XLabs-AI/flux-ip-adapter-v2) IP-adapter
+- for **SD3.5** there are official *Canny*, *Blur* and *Depth* models in addition to existing 3rd party models
+  as well as the [InstantX](https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter) IP-adapter
+
+Plus a couple of new integrated workflows such as [FreeScale](https://github.com/ali-vilab/FreeScale) and [Style Aligned Image Generation](https://style-aligned-gen.github.io/)
+
+And it wouldn't be a *Xmass edition* without a couple of custom themes: *Snowflake* and *Elf-Green*!
+All-in-all, we're at around ~180 commits' worth of updates; check the changelog for the full list
+
+[ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867)
+
+## Details for 2024-12-24
+
+### New models and integrations
+
+- [NVLabs Sana](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px)
+  support for 1.6B 2048px, 1.6B 1024px and 0.6B 512px models
+  **Sana** can synthesize high-resolution images with strong text-image alignment by using **Gemma2** as text-encoder
+  and it's *fast* - typically at least **2x** faster than sd-xl even for the 1.6B variant, and it maintains performance regardless of resolution
+  e.g., rendering at 4k is possible in less than 8GB vram
+  to use, select from *networks -> models -> reference* and models will be auto-downloaded on first use
+  *reference values*: sampler: default (or any flow-match variant), steps: 20, width/height: 1024, guidance scale: 4.5
+  *note*: like other LLM-based text-encoders, sana prefers long and descriptive prompts
+  any short prompt below 300 characters will be auto-expanded using the built-in Gemma LLM before encoding, while long prompts will be passed as-is
+- **ControlNet**
+  - improved support for **Union** controlnets with granular control mode type
+  - added support for the latest [Xinsir ProMax](https://huggingface.co/xinsir/controlnet-union-sdxl-1.0) all-in-one controlnet
+  - added support for multiple **Tiling** controlnets, for example [Xinsir Tile](https://huggingface.co/xinsir/controlnet-tile-sdxl-1.0)
+    *note*: when selecting tiles in control settings, you can also specify non-square ratios,
+    in which case it will use context-aware image resize to maintain the overall composition
+    *note*: available tiling options can be set in settings -> control
+- **IP-Adapter**
+  - FLUX.1 [XLabs](https://huggingface.co/XLabs-AI/flux-ip-adapter-v2) v1 and v2 IP-adapter
+  - FLUX.1 secondary guidance, enabled using *Attention guidance* in the advanced menu
+  - SD 3.5 [InstantX](https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter) IP-adapter
+- [Flux Tools](https://blackforestlabs.ai/flux-1-tools/)
+  **Redux** is actually a tool, **Fill** is an inpaint/outpaint-optimized version of *Flux-dev*
+  **Canny** & **Depth** are optimized versions of *Flux-dev* for their respective tasks: they are *not* ControlNets that work on top of a model
+  to use, go to the image or control interface and select *Flux Tools* in scripts
+  all models are auto-downloaded on first use
+  *note*: All models are [gated](https://github.com/vladmandic/automatic/wiki/Gated) and require acceptance of terms and conditions via the web page
+  *recommended*: Enable on-the-fly [quantization](https://github.com/vladmandic/automatic/wiki/Quantization) or [compression](https://github.com/vladmandic/automatic/wiki/NNCF-Compression) to reduce resource usage
+  *todo*: support for Canny/Depth LoRAs
+  - [Redux](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev): ~0.1GB
+    works together with the existing model: it analyzes the input image and uses that instead of a prompt
+    *optional*: can use a prompt to combine guidance with the input image
+    *recommended*: low denoise strength levels result in more variety
+  - [Fill](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev): ~23.8GB, replaces the currently loaded model
+    *note*: can be used in inpaint/outpaint mode only
+  - [Canny](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev): ~23.8GB, replaces the currently loaded model
+    *recommended*: guidance scale 30
+  - [Depth](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev): ~23.8GB, replaces the currently loaded model
+    *recommended*: guidance scale 10
+- [Flux ControlNet LoRA](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora)
+  as an alternative to standard ControlNets, FLUX.1 also allows LoRA to help guide the generation process
+  both **Depth** and **Canny** LoRAs are available in the standard control menus
+- [StabilityAI SD35 ControlNets](https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets)
+  - In addition to the previously released `InstantX` and `Alimama`, we now have *official* ones from StabilityAI
+- [Style Aligned Image Generation](https://style-aligned-gen.github.io/)
+  enable in scripts, compatible with sd-xl
+  enter multiple prompts in the prompt field, separated by new lines
+  style-aligned applies selected attention layers uniformly to all images to achieve consistency
+  can be used with or without an input image, in which case the first prompt is used to establish the baseline
+  *note*: all prompts are processed as a single batch, so vram is the limiting factor
+- [FreeScale](https://github.com/ali-vilab/FreeScale)
+  enable in scripts, compatible with sd-xl for text and img2img
+  runs iterative generation of images at different scales to achieve better results
+  can render 4k sdxl images
+  *note*: disable live preview to avoid memory issues when generating large images
+
+### Video models
+
+- [Lightricks LTX-Video](https://huggingface.co/Lightricks/LTX-Video)
+  model size: 27.75gb
+  support for 0.9.0, 0.9.1 and custom safetensor-based models with full quantization and offloading support
+  support for text-to-video and image-to-video; to use, select in *scripts -> ltx-video*
+  *reference values*: steps 50, width 704, height 512, frames 161, guidance scale 3.0
+- [Hunyuan Video](https://huggingface.co/tencent/HunyuanVideo)
+  model size: 40.92gb
+  support for text-to-video; to use, select in *scripts -> hunyuan video*
+  basic support only
+  *reference values*: steps 50, width 1280, height 720, frames 129, guidance scale 6.0
+- [Genmo Mochi.1 Preview](https://huggingface.co/genmo/mochi-1-preview)
+  support for text-to-video; to use, select in *scripts -> mochi.1 video*
+  basic support only
+  *reference values*: steps 64, width 848, height 480, frames 19, guidance scale 4.5
+
+*Notes*:
+- all video models are very large and resource intensive!
+  any use on gpus below 16gb and systems below 48gb ram is experimental at best
+- sdnext support for video models is relatively basic, with further optimizations pending community interest
+  any future optimizations would likely have to go into partial loading and execution instead of offloading inactive parts of the model
+- new video models use generic llms for prompting and due to that require very long and descriptive prompts
+- you may need to enable sequential offload for maximum gpu memory savings
+- optionally enable pre-quantization using bnb for additional memory savings
+- reduce the number of frames and/or resolution to reduce memory usage
+
+### UI and workflow improvements
+
+- **Docs**:
+  - New documentation site!
+  - Additional Wiki content: Styles, Wildcards, etc.
+- **LoRA** handler rewrite:
+  - LoRA weights are no longer calculated on-the-fly during model execution, but are pre-calculated at the start
+    this adds perceived overhead at generate startup, but results in overall faster execution as LoRA does not need to be processed on each step
+    thanks @AI-Casanova
+  - LoRA weights can be applied/unapplied on each generate, or they can store weight backups for later use
+    this setting has large performance and resource implications, see the [Offload](https://github.com/vladmandic/automatic/wiki/Offload) wiki for details
+  - LoRA name in prompt can now also be an absolute path to a LoRA file, even if the LoRA is not indexed
+    example: ``
+  - LoRA name in prompt can now also be a path to a LoRA file on `huggingface`
+    example: ``
+- **Model loader** improvements:
   - detect model components on model load fail
+  - allow passing absolute path to model loader
   - Flux, SD35: force unload model
   - Flux: apply `bnb` quant when loading *unet/transformer*
   - Flux: all-in-one safetensors example:
   - Flux: do not recast quants
-- Sampler improvements
-  - update DPM FlowMatch samplers
-- Fixes:
-  - update `diffusers`
-  - fix README links
-  - fix sdxl controlnet single-file loader
-  - relax settings validator
+- **Memory** improvements:
+  - faster and more compatible *balanced offload* mode
+  - balanced offload: units are now in percentage instead of bytes
+  - balanced offload: add both high and low watermarks, defaults as below
+    `0.25` for low-watermark: skip offload if memory usage is below 25%
+    `0.70` for high-watermark: must offload if memory usage is above 70%
+  - balanced offload will attempt to run offload as non-blocking and force gc at the end
+  - change-in-behavior:
+    low-end systems, triggered by either `lowvram` or by detection of <=4GB vram, will use *sequential offload*
+    all other systems use *balanced offload* by default (can be changed in settings)
+    previous behavior was to use *model offload* on systems with <=8GB and `medvram`, and no offload by default
+  - VAE upcast is now disabled by default on all systems
+    if you have issues with image decode, you'll need to enable it manually
+- **UI**:
+  - improved stats on generate completion
+  - improved live preview display and performance
+  - improved accordion behavior
+  - auto-size networks height for sidebar
+  - control: hide preview column by default
+  - control: option to hide input column
+  - control: add stats
+  - settings: reorganized and simplified
+  - browser -> server logging framework
+  - add additional themes: `black-teal-reimagined`, thanks @Artheriax
+- **Batch**
+  - image batch processing will use caption files if they exist instead of the default prompt
+
+### Updates
+
+- **Quantization**
+  - Add `TorchAO` *pre* (during load) and *post* (during execution) quantization
+    **torchao** supports 4 different int-based and 3 float-based quantization schemes
+    This is in addition to existing support for:
+    - `BitsAndBytes` with 3 float-based quantization schemes
+    - `Optimum.Quanto` with 3 int-based and 2 float-based quantization schemes
+    - `GGUF` with pre-quantized weights
+  - Switch `GGUF` loader from custom to diffusers native
+- **IPEX**: update to IPEX 2.5.10+xpu
+- **OpenVINO**:
+  - update to 2024.6.0
+  - disable model caching by default
+- **Sampler** improvements
+  - UniPC, DEIS, SA, DPM-Multistep: allow FlowMatch sigma method and prediction type
+  - Euler FlowMatch: add sigma methods (*karras/exponential/betas*)
+  - Euler FlowMatch: allow using timestep presets to set sigmas
+  - DPM FlowMatch: update all and add sigma methods
+  - BDIA-DDIM: *experimental* new scheduler
+  - UFOGen: *experimental* new scheduler
+
+### Fixes
+
+- add `SD_NO_CACHE=true` env variable to disable file/folder caching
+- add settings -> networks -> embeddings -> enable/disable
+- update `diffusers`
+- fix README links
+- fix sdxl controlnet single-file loader
+- relax settings validator
+- improve js progress calls resiliency
+- fix text-to-video pipeline
+- avoid live-preview if vae-decode is running
+- allow xyz-grid with multi-axis s&r
+- fix xyz-grid with lora
+- fix api script callbacks
+- fix gpu memory monitoring
+- simplify img2img/inpaint/sketch canvas handling
+- fix prompt caching
+- fix xyz grid skip final pass
+- fix sd upscale script
+- fix cogvideox-i2v
+- lora auto-apply tags remove duplicates
+- control load model on-demand if not already loaded
+- taesd limit render to 2024px
+- taesd downscale preview to 1024px max: configurable in settings -> live preview
+- uninstall conflicting `wandb` package
+- don't skip diffusers version check if quick is specified
+- notify on torch install
+- detect pipeline from diffusers folder-style model
+- do not recast flux quants
+- fix xyz-grid with lora none
+- fix svd image2video
+- fix gallery display during generate
+- fix wildcards replacement to be unique
+- fix animatediff-xl
+- fix pag with batch count

 ## Update for 2024-11-21

@@ -270,7 +479,7 @@ A month later and with nearly 300 commits, here is the latest [SD.Next](https://

 #### New models for 2024-10-23

-- New fine-tuned [CLiP-ViT-L]((https://huggingface.co/zer0int/CLIP-GmP-ViT-L-14)) 1st stage **text-encoders** used by most models (SD15/SDXL/SD3/Flux/etc.) brings additional details to your images
+- New fine-tuned [CLiP-ViT-L](https://huggingface.co/zer0int/CLIP-GmP-ViT-L-14) 1st stage **text-encoders** used by most models (SD15/SDXL/SD3/Flux/etc.) brings additional details to your images
 - New models:
   [Stable Diffusion 3.5 Large](https://huggingface.co/stabilityai/stable-diffusion-3.5-large)
   [OmniGen](https://arxiv.org/pdf/2409.11340)
@@ -370,7 +579,7 @@ And there are also other goodies like multiple *XYZ grid* improvements, addition
   - xyz grid support for sampler options
   - metadata updates for sampler options
   - modernui updates for sampler options
-  - *note* sampler options defaults are not save in ui settings, they are saved in server settings
+  - *note* sampler options defaults are not saved in ui settings, they are saved in server settings
     to apply your defaults, set ui values and apply via *system -> settings -> apply settings*

  *sampler options*:
@@ -602,7 +811,7 @@ Examples:
 - vae is list of manually downloaded safetensors
 - text-encoder is list of predefined and manually downloaded text-encoders
 - **controlnet** support:
-  support for **InstantX/Shakker-Labs** models including [Union-Pro](InstantX/FLUX.1-dev-Controlnet-Union)
+  support for **InstantX/Shakker-Labs** models including [Union-Pro](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union)
   note that flux controlnet models are large, up to 6.6GB on top of already large base model!
as such, you may need to use offloading:sequential which is not as fast, but uses far less memory when using union model, you must also select control mode in the control unit @@ -2117,7 +2326,7 @@ Also new is support for **SDXL-Turbo** as well as new **Kandinsky 3** models and - in *Advanced* params - allows control of *latent clamping*, *color centering* and *range maximization* - supported by *XYZ grid* - - [SD21 Turbo](https://huggingface.co/stabilityai/sd-turbo) and [SDXL Turbo]() support + - [SD21 Turbo](https://huggingface.co/stabilityai/sd-turbo) and [SDXL Turbo](https://huggingface.co/stabilityai/sdxl-turbo) support - just set CFG scale (0.0-1.0) and steps (1-3) to a very low value - compatible with original StabilityAI SDXL-Turbo or any of the newer merges - download safetensors or select from networks -> reference diff --git a/README.md b/README.md index 1644281ad..2a95373d0 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@
-SD.Next +SD.Next **Image Diffusion implementation with advanced features** @@ -8,15 +8,16 @@ [![Discord](https://img.shields.io/discord/1101998836328697867?logo=Discord&svg=true)](https://discord.gg/VjvR2tabEX) [![Sponsors](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/vladmandic) -[Wiki](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.gg/VjvR2tabEX) | [Changelog](CHANGELOG.md) +[Docs](https://vladmandic.github.io/sdnext-docs/) | [Wiki](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.gg/VjvR2tabEX) | [Changelog](CHANGELOG.md)

## Table of contents +- [Documentation](https://vladmandic.github.io/sdnext-docs/) - [SD.Next Features](#sdnext-features) -- [Model support](#model-support) +- [Model support](#model-support) and [Specifications]() - [Platform support](#platform-support) - [Getting started](#getting-started) @@ -25,7 +26,7 @@ All individual features are not listed here, instead check [ChangeLog](CHANGELOG.md) for full list of changes - Multiple UIs! ▹ **Standard | Modern** -- Multiple diffusion models! +- Multiple [diffusion models](https://vladmandic.github.io/sdnext-docs/Model-Support/)! - Built-in Control for Text, Image, Batch and video processing! - Multiplatform! ▹ **Windows | Linux | MacOS | nVidia | AMD | IntelArc/IPEX | DirectML | OpenVINO | ONNX+Olive | ZLUDA** @@ -34,9 +35,7 @@ All individual features are not listed here, instead check [ChangeLog](CHANGELOG - Platform specific autodetection and tuning performed on install - Optimized processing with latest `torch` developments with built-in support for `torch.compile` and multiple compile backends: *Triton, ZLUDA, StableFast, DeepCache, OpenVINO, NNCF, IPEX, OneDiff* -- Improved prompt parser - Built-in queue management -- Enterprise level logging and hardened API - Built in installer with automatic updates and dependency management - Mobile compatible @@ -49,42 +48,13 @@ All individual features are not listed here, instead check [ChangeLog](CHANGELOG ![screenshot-modernui](https://github.com/user-attachments/assets/39e3bc9a-a9f7-4cda-ba33-7da8def08032) -For screenshots and informations on other available themes, see [Themes Wiki](https://github.com/vladmandic/automatic/wiki/Themes) +For screenshots and informations on other available themes, see [Themes](https://vladmandic.github.io/sdnext-docs/Themes/)
## Model support -Additional models will be added as they become available and there is public interest in them -See [models overview](https://github.com/vladmandic/automatic/wiki/Models) for details on each model, including their architecture, complexity and other info - -- [RunwayML Stable Diffusion](https://github.com/Stability-AI/stablediffusion/) 1.x and 2.x *(all variants)* -- [StabilityAI Stable Diffusion XL](https://github.com/Stability-AI/generative-models), [StabilityAI Stable Diffusion 3.0](https://stability.ai/news/stable-diffusion-3-medium) Medium, [StabilityAI Stable Diffusion 3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-large) Medium, Large, Large Turbo -- [StabilityAI Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid) Base, XT 1.0, XT 1.1 -- [StabilityAI Stable Cascade](https://github.com/Stability-AI/StableCascade) *Full* and *Lite* -- [Black Forest Labs FLUX.1](https://blackforestlabs.ai/announcing-black-forest-labs/) Dev, Schnell -- [AuraFlow](https://huggingface.co/fal/AuraFlow) -- [AlphaVLLM Lumina-Next-SFT](https://huggingface.co/Alpha-VLLM/Lumina-Next-SFT-diffusers) -- [Playground AI](https://huggingface.co/playgroundai/playground-v2-256px-base) *v1, v2 256, v2 512, v2 1024 and latest v2.5* -- [Tencent HunyuanDiT](https://github.com/Tencent/HunyuanDiT) -- [OmniGen](https://arxiv.org/pdf/2409.11340) -- [Meissonic](https://github.com/viiika/Meissonic) -- [Kwai Kolors](https://huggingface.co/Kwai-Kolors/Kolors) -- [CogView 3+](https://huggingface.co/THUDM/CogView3-Plus-3B) -- [LCM: Latent Consistency Models](https://github.com/openai/consistency_models) -- [aMUSEd](https://huggingface.co/amused/amused-256) 256 and 512 -- [Segmind Vega](https://huggingface.co/segmind/Segmind-Vega), [Segmind SSD-1B](https://huggingface.co/segmind/SSD-1B), [Segmind SegMoE](https://github.com/segmind/segmoe) *SD and SD-XL*, [Segmind SD Distilled](https://huggingface.co/blog/sd_distillation) *(all variants)* -- [Kandinsky](https://github.com/ai-forever/Kandinsky-2) *2.1 and 2.2 and latest 3.0* -- [PixArt-α XL 2](https://github.com/PixArt-alpha/PixArt-alpha) *Medium and Large*, [PixArt-Σ](https://github.com/PixArt-alpha/PixArt-sigma) -- [Warp Wuerstchen](https://huggingface.co/blog/wuertschen) -- [Tsinghua UniDiffusion](https://github.com/thu-ml/unidiffuser) -- [DeepFloyd IF](https://github.com/deep-floyd/IF) *Medium and Large* -- [ModelScope T2V](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b) -- [BLIP-Diffusion](https://dxli94.github.io/BLIP-Diffusion-website/) -- [KOALA 700M](https://github.com/youngwanLEE/sdxl-koala) -- [VGen](https://huggingface.co/ali-vilab/i2vgen-xl) -- [SDXS](https://github.com/IDKiro/sdxs) -- [Hyper-SD](https://huggingface.co/ByteDance/Hyper-SD) +SD.Next supports broad range of models: [supported models](https://vladmandic.github.io/sdnext-docs/Model-Support/) and [model specs](https://vladmandic.github.io/sdnext-docs/Models/) ## Platform support @@ -97,47 +67,29 @@ See [models overview](https://github.com/vladmandic/automatic/wiki/Models) for d - Any GPU or device compatible with **OpenVINO** libraries on both *Windows and Linux* - *Apple M1/M2* on *OSX* using built-in support in Torch with **MPS** optimizations - *ONNX/Olive* -- *AMD* GPUs on Windows using **ZLUDA** libraries +- *AMD* GPUs on Windows using **ZLUDA** libraries ## Getting started -- Get started with **SD.Next** by following the [installation instructions](https://github.com/vladmandic/automatic/wiki/Installation) -- For more details, 
check out [advanced installation](https://github.com/vladmandic/automatic/wiki/Advanced-Install) guide -- List and explanation of [command line arguments](https://github.com/vladmandic/automatic/wiki/CLI-Arguments) +- Get started with **SD.Next** by following the [installation instructions](https://vladmandic.github.io/sdnext-docs/Installation/) +- For more details, check out [advanced installation](https://vladmandic.github.io/sdnext-docs/Advanced-Install/) guide +- List and explanation of [command line arguments](https://vladmandic.github.io/sdnext-docs/CLI-Arguments/) - Install walkthrough [video](https://www.youtube.com/watch?v=nWTnTyFTuAs) > [!TIP] > And for platform specific information, check out -> [WSL](https://github.com/vladmandic/automatic/wiki/WSL) | [Intel Arc](https://github.com/vladmandic/automatic/wiki/Intel-ARC) | [DirectML](https://github.com/vladmandic/automatic/wiki/DirectML) | [OpenVINO](https://github.com/vladmandic/automatic/wiki/OpenVINO) | [ONNX & Olive](https://github.com/vladmandic/automatic/wiki/ONNX-Runtime) | [ZLUDA](https://github.com/vladmandic/automatic/wiki/ZLUDA) | [AMD ROCm](https://github.com/vladmandic/automatic/wiki/AMD-ROCm) | [MacOS](https://github.com/vladmandic/automatic/wiki/MacOS-Python.md) | [nVidia](https://github.com/vladmandic/automatic/wiki/nVidia) +> [WSL](https://vladmandic.github.io/sdnext-docs/WSL/) | [Intel Arc](https://vladmandic.github.io/sdnext-docs/Intel-ARC/) | [DirectML](https://vladmandic.github.io/sdnext-docs/DirectML/) | [OpenVINO](https://vladmandic.github.io/sdnext-docs/OpenVINO/) | [ONNX & Olive](https://vladmandic.github.io/sdnext-docs/ONNX-Runtime/) | [ZLUDA](https://vladmandic.github.io/sdnext-docs/ZLUDA/) | [AMD ROCm](https://vladmandic.github.io/sdnext-docs/AMD-ROCm/) | [MacOS](https://vladmandic.github.io/sdnext-docs/MacOS-Python/) | [nVidia](https://vladmandic.github.io/sdnext-docs/nVidia/) | [Docker](https://vladmandic.github.io/sdnext-docs/Docker/) > [!WARNING] -> If you run into issues, check out [troubleshooting](https://github.com/vladmandic/automatic/wiki/Troubleshooting) and [debugging](https://github.com/vladmandic/automatic/wiki/Debug) guides +> If you run into issues, check out [troubleshooting](https://vladmandic.github.io/sdnext-docs/Troubleshooting/) and [debugging](https://vladmandic.github.io/sdnext-docs/Debug/) guides > [!TIP] -> All command line options can also be set via env variable +> All command line options can also be set via env variable > For example `--debug` is same as `set SD_DEBUG=true` -## Backend support - -**SD.Next** supports two main backends: *Diffusers* and *Original*: - -- **Diffusers**: Based on new [Huggingface Diffusers](https://huggingface.co/docs/diffusers/index) implementation - Supports *all* models listed below - This backend is set as default for new installations -- **Original**: Based on [LDM](https://github.com/Stability-AI/stablediffusion) reference implementation and significantly expanded on by [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) - This backend and is fully compatible with most existing functionality and extensions written for *A1111 SDWebUI* - Supports **SD 1.x** and **SD 2.x** models - All other model types such as *SD-XL, LCM, Stable Cascade, PixArt, Playground, Segmind, Kandinsky, etc.* require backend **Diffusers** - -### Collab - -- We'd love to have additional maintainers (with comes with full repo rights). If you're interested, ping us! 
-- In addition to general cross-platform code, desire is to have a lead for each of the main platforms
-This should be fully cross-platform, but we'd really love to have additional contributors and/or maintainers to join and help lead the efforts on different platforms
-
 ### Credits

-- Main credit goes to [Automatic1111 WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) for original codebase
+- Main credit goes to [Automatic1111 WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) for the original codebase
 - Additional credits are listed in [Credits](https://github.com/AUTOMATIC1111/stable-diffusion-webui/#credits)
 - Licenses for modules are listed in [Licenses](html/licenses.html)

@@ -154,8 +106,8 @@

 ### Docs

-If you're unsure how to use a feature, best place to start is [Wiki](https://github.com/vladmandic/automatic/wiki) and if its not there,
-check [ChangeLog](CHANGELOG.md) for when feature was first introduced as it will always have a short note on how to use it
+If you're unsure how to use a feature, the best place to start is [Docs](https://vladmandic.github.io/sdnext-docs/) and if it's not there,
+check [ChangeLog](https://vladmandic.github.io/sdnext-docs/CHANGELOG/) for when the feature was first introduced, as it will always have a short note on how to use it

 ### Sponsors

diff --git a/TODO.md b/TODO.md
index 973e062dc..6d89c838f 100644
--- a/TODO.md
+++ b/TODO.md
@@ -2,21 +2,30 @@
 Main ToDo list can be found at [GitHub projects](https://github.com/users/vladmandic/projects)

-## Future Candidates
+## Pending

-- SD35 IPAdapter:
-- SD35 LoRA:
-- Flux IPAdapter:
-- Flux Fill/ControlNet/Redux:
-- Flux NF4:
-- SANA:
+- LoRA direct with caching
+- Previewer issues
+- Redesign postprocessing

-## Other
+## Future Candidates

+- Flux NF4 loader:
 - IPAdapter negative:
 - Control API enhance scripts compatibility
+- PixelSmith:

-## Workaround in place
+## Code TODO

-- GGUF
-- FlowMatch
+- TODO install: python 3.12.4 or higher cause a mess with pydantic
+- TODO install: enable ROCm for windows when available
+- TODO resize image: enable full VAE mode for resize-latent
+- TODO processing: remove duplicate mask params
+- TODO flux: fix loader for civitai nf4 models
+- TODO model loader: implement model in-memory caching
+- TODO hypertile: vae breaks when using non-standard sizes
+- TODO model load: force-reloading entire model as loading transformers only leads to massive memory usage
+- TODO lora load: direct with bnb
+- TODO lora make: support quantized flux
+- TODO control: support scripts via api
+- TODO modernui: monkey-patch for missing tabs.select event
diff --git a/cli/api-model.js b/cli/api-model.js
new file mode 100755
index 000000000..e2ce5344a
--- /dev/null
+++ b/cli/api-model.js
@@ -0,0 +1,30 @@
+#!/usr/bin/env node
+
+const sd_url = process.env.SDAPI_URL || 'http://127.0.0.1:7860';
+const sd_username = process.env.SDAPI_USR;
+const sd_password = process.env.SDAPI_PWD;
+const models = [
+  '/mnt/models/stable-diffusion/sd15/lyriel_v16.safetensors',
+  '/mnt/models/stable-diffusion/flux/flux-finesse_v2-f1h-fp8.safetensors',
+  '/mnt/models/stable-diffusion/sdxl/TempestV0.1-Artistic.safetensors',
+];
+
+async function options(data) {
+  const method = 'POST';
+  const headers = new Headers();
+  const body = JSON.stringify(data);
+  headers.set('Content-Type', 'application/json');
+  if (sd_username && sd_password) headers.set('Authorization', `Basic ${btoa(`${sd_username}:${sd_password}`)}`); // encode the actual credentials, not a literal string
+  const
res = await fetch(`${sd_url}/sdapi/v1/options`, { method, headers, body }); + return res; +} + +async function main() { + for (const model of models) { + console.log('model:', model); + const res = await options({ sd_model_checkpoint: model }); + console.log('result:', res); + } +} + +main(); diff --git a/cli/api-pulid.js b/cli/api-pulid.js index fde0ae43b..033824e9b 100755 --- a/cli/api-pulid.js +++ b/cli/api-pulid.js @@ -10,12 +10,13 @@ const argparse = require('argparse'); const sd_url = process.env.SDAPI_URL || 'http://127.0.0.1:7860'; const sd_username = process.env.SDAPI_USR; const sd_password = process.env.SDAPI_PWD; +let args = {}; function b64(file) { const data = fs.readFileSync(file); - const b64 = Buffer.from(data).toString('base64'); + const b64str = Buffer.from(data).toString('base64'); const ext = path.extname(file).replace('.', ''); - str = `data:image/${ext};base64,${b64}`; + const str = `data:image/${ext};base64,${b64str}`; // console.log('b64:', ext, b64.length); return str; } @@ -39,7 +40,16 @@ function options() { if (args.pulid) { const b64image = b64(args.pulid); opt.script_name = 'pulid'; - opt.script_args = [b64image, 0.9]; + opt.script_args = [ + b64image, // b64 encoded image, required param + 0.9, // strength, optional + 20, // zero, optional + 'dpmpp_sde', // sampler, optional + 'v2', // ortho, optional + true, // restore (disable pulid after run), optional + true, // offload, optional + 'v1.1', // version, optional + ]; } // console.log('options:', opt); return opt; @@ -53,8 +63,8 @@ function init() { parser.add_argument('--height', { type: 'int', help: 'height' }); parser.add_argument('--pulid', { type: 'str', help: 'pulid init image' }); parser.add_argument('--output', { type: 'str', help: 'output path' }); - const args = parser.parse_args(); - return args + const parsed = parser.parse_args(); + return parsed; } async function main() { @@ -73,12 +83,12 @@ async function main() { console.log('result:', json.info); for (const i in json.images) { // eslint-disable-line guard-for-in const file = args.output || `/tmp/test-${i}.jpg`; - const data = atob(json.images[i]) + const data = atob(json.images[i]); fs.writeFileSync(file, data, 'binary'); console.log('image saved:', file); } } } -const args = init(); +args = init(); main(); diff --git a/cli/full-test.sh b/cli/full-test.sh index e410528ad..912dc3a5b 100755 --- a/cli/full-test.sh +++ b/cli/full-test.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +node cli/api-txt2img.js +node cli/api-pulid.js + source venv/bin/activate echo image-exif python cli/api-info.py --input html/logo-bg-0.jpg diff --git a/cli/load-unet.py b/cli/load-unet.py index 2398cdb64..c910101b0 100644 --- a/cli/load-unet.py +++ b/cli/load-unet.py @@ -33,13 +33,13 @@ def set_module_tensor( stats.dtypes[value.dtype] = 0 stats.dtypes[value.dtype] += 1 if name in module._buffers: # pylint: disable=protected-access - module._buffers[name] = value.to(device=device, dtype=dtype, non_blocking=True) # pylint: disable=protected-access + module._buffers[name] = value.to(device=device, dtype=dtype) # pylint: disable=protected-access if 'buffers' not in stats.weights: stats.weights['buffers'] = 0 stats.weights['buffers'] += 1 elif value is not None: param_cls = type(module._parameters[name]) # pylint: disable=protected-access - module._parameters[name] = param_cls(value, requires_grad=old_value.requires_grad).to(device, dtype=dtype, non_blocking=True) # pylint: disable=protected-access + module._parameters[name] = param_cls(value, 
requires_grad=old_value.requires_grad).to(device, dtype=dtype) # pylint: disable=protected-access if 'parameters' not in stats.weights: stats.weights['parameters'] = 0 stats.weights['parameters'] += 1 diff --git a/configs/flux/vae/config.json b/configs/flux/vae/config.json index b43183d0f..7ecb342c2 100644 --- a/configs/flux/vae/config.json +++ b/configs/flux/vae/config.json @@ -14,7 +14,7 @@ "DownEncoderBlock2D", "DownEncoderBlock2D" ], - "force_upcast": true, + "force_upcast": false, "in_channels": 3, "latent_channels": 16, "latents_mean": null, diff --git a/configs/sd15/vae/config.json b/configs/sd15/vae/config.json index 55d78924f..2cba0e824 100644 --- a/configs/sd15/vae/config.json +++ b/configs/sd15/vae/config.json @@ -14,6 +14,7 @@ "DownEncoderBlock2D", "DownEncoderBlock2D" ], + "force_upcast": false, "in_channels": 3, "latent_channels": 4, "layers_per_block": 2, diff --git a/configs/sd3/vae/config.json b/configs/sd3/vae/config.json index 58e7764fb..f6f4e8684 100644 --- a/configs/sd3/vae/config.json +++ b/configs/sd3/vae/config.json @@ -15,7 +15,7 @@ "DownEncoderBlock2D", "DownEncoderBlock2D" ], - "force_upcast": true, + "force_upcast": false, "in_channels": 3, "latent_channels": 16, "latents_mean": null, diff --git a/configs/sdxl/vae/config.json b/configs/sdxl/vae/config.json index a66a171ba..1c7a60866 100644 --- a/configs/sdxl/vae/config.json +++ b/configs/sdxl/vae/config.json @@ -15,7 +15,7 @@ "DownEncoderBlock2D", "DownEncoderBlock2D" ], - "force_upcast": true, + "force_upcast": false, "in_channels": 3, "latent_channels": 4, "layers_per_block": 2, diff --git a/extensions-builtin/Lora/lora_extract.py b/extensions-builtin/Lora/lora_extract.py index c2e0a275b..1d92f3c6e 100644 --- a/extensions-builtin/Lora/lora_extract.py +++ b/extensions-builtin/Lora/lora_extract.py @@ -182,19 +182,6 @@ def make_lora(fn, maxrank, auto_rank, rank_ratio, modules, overwrite): progress.remove_task(task) t3 = time.time() - # TODO: Handle quant for Flux - # if 'te' in modules and getattr(shared.sd_model, 'transformer', None) is not None: - # for name, module in shared.sd_model.transformer.named_modules(): - # if "norm" in name and "linear" not in name: - # continue - # weights_backup = getattr(module, "network_weights_backup", None) - # if weights_backup is None: - # continue - # module.svdhandler = SVDHandler() - # module.svdhandler.network_name = "lora_transformer_" + name.replace(".", "_") - # module.svdhandler.decompose(module.weight, weights_backup) - # module.svdhandler.findrank(rank, rank_ratio) - lora_state_dict = {} for sub in ['text_encoder', 'text_encoder_2', 'unet', 'transformer']: submodel = getattr(shared.sd_model, sub, None) diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py index 5e6eaef6c..a410a8e3b 100644 --- a/extensions-builtin/Lora/network_lora.py +++ b/extensions-builtin/Lora/network_lora.py @@ -22,7 +22,6 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.dim = weights.w["lora_down.weight"].shape[0] def create_module(self, weights, key, none_ok=False): - from modules.shared import opts weight = weights.get(key) if weight is None and none_ok: return None @@ -32,7 +31,7 @@ def create_module(self, weights, key, none_ok=False): if is_linear: weight = weight.reshape(weight.shape[0], -1) module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif is_conv and key == "lora_down.weight" or key == "dyn_up": + elif is_conv and (key == "lora_down.weight" or key == "dyn_up"): if len(weight.shape) == 
2: weight = weight.reshape(weight.shape[0], -1, 1, 1) if weight.shape[2] != 1 or weight.shape[3] != 1: @@ -41,7 +40,7 @@ def create_module(self, weights, key, none_ok=False): module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) elif is_conv and key == "lora_mid.weight": module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) - elif is_conv and key == "lora_up.weight" or key == "dyn_down": + elif is_conv and (key == "lora_up.weight" or key == "dyn_down"): module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) else: raise AssertionError(f'Lora unsupported: layer={self.network_key} type={type(self.sd_module).__name__}') @@ -49,8 +48,6 @@ def create_module(self, weights, key, none_ok=False): if weight.shape != module.weight.shape: weight = weight.reshape(module.weight.shape) module.weight.copy_(weight) - if opts.lora_load_gpu: - module = module.to(device=devices.device, dtype=devices.dtype) module.weight.requires_grad_(False) return module diff --git a/extensions-builtin/Lora/network_overrides.py b/extensions-builtin/Lora/network_overrides.py index 5334f3c1b..b5c28b718 100644 --- a/extensions-builtin/Lora/network_overrides.py +++ b/extensions-builtin/Lora/network_overrides.py @@ -26,7 +26,6 @@ force_models = [ # forced always 'sc', - # 'sd3', 'kandinsky', 'hunyuandit', 'auraflow', diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index db617ee5b..fd6287c62 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -88,7 +88,7 @@ def assign_network_names_to_compvis_modules(sd_model): network_name = name.replace(".", "_") network_layer_mapping[network_name] = module module.network_layer_name = network_name - shared.sd_model.network_layer_mapping = network_layer_mapping + sd_model.network_layer_mapping = network_layer_mapping def load_diffusers(name, network_on_disk, lora_scale=shared.opts.extra_networks_default_multiplier) -> network.Network: @@ -141,7 +141,7 @@ def load_network(name, network_on_disk) -> network.Network: sd = sd_models.read_state_dict(network_on_disk.filename, what='network') if shared.sd_model_type == 'f1': # if kohya flux lora, convert state_dict sd = lora_convert._convert_kohya_flux_lora_to_diffusers(sd) or sd # pylint: disable=protected-access - assign_network_names_to_compvis_modules(shared.sd_model) # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0 + assign_network_names_to_compvis_modules(shared.sd_model) keys_failed_to_match = {} matched_networks = {} bundle_embeddings = {} diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index ffbef47d9..24723dd7f 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -5,7 +5,7 @@ from network import NetworkOnDisk from ui_extra_networks_lora import ExtraNetworksPageLora from extra_networks_lora import ExtraNetworkLora -from modules import script_callbacks, extra_networks, ui_extra_networks, ui_models # pylint: disable=unused-import +from modules import script_callbacks, extra_networks, ui_extra_networks, ui_models, shared # pylint: disable=unused-import re_lora = re.compile(" 3: # TODO python 3.12.4 or higher cause a mess with pydantic + if int(sys.version_info.major) == 3 and int(sys.version_info.minor) == 12 and int(sys.version_info.micro) > 3: 
# TODO install: python 3.12.4 or higher cause a mess with pydantic log.error(f"Python version incompatible: {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro} required 3.12.3 or lower") if reason is not None: log.error(reason) @@ -457,9 +457,9 @@ def check_python(supported_minors=[9, 10, 11, 12], reason=None): # check diffusers version def check_diffusers(): - if args.skip_all or args.skip_requirements: + if args.skip_all or args.skip_git: return - sha = 'b5fd6f13f5434d69d919cc8cedf0b11db664cf06' + sha = '6dfaec348780c6153a4cfd03a01972a291d67f82' # diffusers commit hash pkg = pkg_resources.working_set.by_key.get('diffusers', None) minor = int(pkg.version.split('.')[1] if pkg is not None else 0) cur = opts.get('diffusers_version', '') if minor > 0 else '' @@ -483,6 +483,7 @@ def check_onnx(): def check_torchao(): + """ if args.skip_all or args.skip_requirements: return if installed('torchao', quiet=True): @@ -492,6 +493,8 @@ def check_torchao(): pip('uninstall --yes torchao', ignore=True, quiet=True, uv=False) for m in [m for m in sys.modules if m.startswith('torchao')]: del sys.modules[m] + """ + return def install_cuda(): @@ -549,7 +552,7 @@ def install_rocm_zluda(): log.info(msg) torch_command = '' if sys.platform == "win32": - # TODO after ROCm for Windows is released + # TODO install: enable ROCm for windows when available if args.device_id is not None: if os.environ.get('HIP_VISIBLE_DEVICES', None) is not None: @@ -634,25 +637,30 @@ def install_ipex(torch_command): os.environ.setdefault('NEOReadDebugKeys', '1') if os.environ.get("ClDeviceGlobalMemSizeAvailablePercent", None) is None: os.environ.setdefault('ClDeviceGlobalMemSizeAvailablePercent', '100') + if os.environ.get("PYTORCH_ENABLE_XPU_FALLBACK", None) is None: + os.environ.setdefault('PYTORCH_ENABLE_XPU_FALLBACK', '1') if "linux" in sys.platform: - torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.3.1+cxx11.abi torchvision==0.18.1+cxx11.abi intel-extension-for-pytorch==2.3.110+xpu oneccl_bind_pt==2.3.100+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/') + torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.5.1+cxx11.abi torchvision==0.20.1+cxx11.abi intel-extension-for-pytorch==2.5.10+xpu oneccl_bind_pt==2.5.0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/') # torch_command = os.environ.get('TORCH_COMMAND', 'torch torchvision --index-url https://download.pytorch.org/whl/test/xpu') # test wheels are stable previews, significantly slower than IPEX # os.environ.setdefault('TENSORFLOW_PACKAGE', 'tensorflow==2.15.1 intel-extension-for-tensorflow[xpu]==2.15.0.1') else: torch_command = os.environ.get('TORCH_COMMAND', '--pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/xpu') # torchvision doesn't exist on test/stable branch for windows - install(os.environ.get('OPENVINO_PACKAGE', 'openvino==2024.3.0'), 'openvino', ignore=True) - install('nncf==2.7.0', 'nncf', ignore=True) + install(os.environ.get('OPENVINO_PACKAGE', 'openvino==2024.5.0'), 'openvino', ignore=True) + install('nncf==2.7.0', ignore=True, no_deps=True) # requires older pandas install(os.environ.get('ONNXRUNTIME_PACKAGE', 'onnxruntime-openvino'), 'onnxruntime-openvino', ignore=True) return torch_command def install_openvino(torch_command): - check_python(supported_minors=[8, 9, 10, 11, 12], reason='OpenVINO backend requires Python 3.9, 3.10 or 3.11') + check_python(supported_minors=[9, 10, 11, 12], reason='OpenVINO backend requires Python 
3.9, 3.10 or 3.11') log.info('OpenVINO: selected') - torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.3.1+cpu torchvision==0.18.1+cpu --index-url https://download.pytorch.org/whl/cpu') - install(os.environ.get('OPENVINO_PACKAGE', 'openvino==2024.3.0'), 'openvino') + if sys.platform == 'darwin': + torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.3.1 torchvision==0.18.1') + else: + torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.3.1+cpu torchvision==0.18.1+cpu --index-url https://download.pytorch.org/whl/cpu') + install(os.environ.get('OPENVINO_PACKAGE', 'openvino==2024.6.0'), 'openvino') install(os.environ.get('ONNXRUNTIME_PACKAGE', 'onnxruntime-openvino'), 'onnxruntime-openvino', ignore=True) - install('nncf==2.12.0', 'nncf') + install('nncf==2.14.1', 'nncf') os.environ.setdefault('PYTORCH_TRACING_MODE', 'TORCHFX') if os.environ.get("NEOReadDebugKeys", None) is None: os.environ.setdefault('NEOReadDebugKeys', '1') @@ -682,7 +690,9 @@ def install_torch_addons(): if opts.get('nncf_compress_weights', False) and not args.use_openvino: install('nncf==2.7.0', 'nncf') if opts.get('optimum_quanto_weights', False): - install('optimum-quanto', 'optimum-quanto') + install('optimum-quanto==0.2.6', 'optimum-quanto') + if not args.experimental: + uninstall('wandb', quiet=True) if triton_command is not None: install(triton_command, 'triton', quiet=True) @@ -727,8 +737,6 @@ def check_torch(): torch_command = install_rocm_zluda() elif is_ipex_available: torch_command = install_ipex(torch_command) - elif allow_openvino: - torch_command = install_openvino(torch_command) else: machine = platform.machine() @@ -746,6 +754,8 @@ def check_torch(): log.warning('Torch: CPU-only version installed') torch_command = os.environ.get('TORCH_COMMAND', 'torch torchvision') if 'torch' in torch_command and not args.version: + if not installed('torch'): + log.info(f'Torch: download and install in progress... 
cmd="{torch_command}"') install(torch_command, 'torch torchvision', quiet=True) else: try: @@ -999,8 +1009,8 @@ def install_optional(): install('basicsr') install('gfpgan') install('clean-fid') - install('optimum-quanto', ignore=True) - install('bitsandbytes', ignore=True) + install('optimum-quanto=0.2.6', ignore=True) + install('bitsandbytes==0.45.0', ignore=True) install('pynvml', ignore=True) install('ultralytics==8.3.40', ignore=True) install('Cython', ignore=True) @@ -1174,9 +1184,15 @@ def same(ver): def check_venv(): + def try_relpath(p): + try: + return os.path.relpath(p) + except ValueError: + return p + import site - pkg_path = [os.path.relpath(p) for p in site.getsitepackages() if os.path.exists(p)] - log.debug(f'Packages: venv={os.path.relpath(sys.prefix)} site={pkg_path}') + pkg_path = [try_relpath(p) for p in site.getsitepackages() if os.path.exists(p)] + log.debug(f'Packages: venv={try_relpath(sys.prefix)} site={pkg_path}') for p in pkg_path: invalid = [] for f in os.listdir(p): diff --git a/javascript/base.css b/javascript/base.css index 7daa8b2bd..a30a71845 100644 --- a/javascript/base.css +++ b/javascript/base.css @@ -25,7 +25,6 @@ .progressDiv .progress { width: 0%; height: 20px; background: #0060df; color: white; font-weight: bold; line-height: 20px; padding: 0 8px 0 0; text-align: right; overflow: visible; white-space: nowrap; padding: 0 0.5em; } .livePreview { position: absolute; z-index: 50; background-color: transparent; width: -moz-available; width: -webkit-fill-available; } .livePreview img { position: absolute; object-fit: contain; width: 100%; height: 100%; } -.dark .livePreview { background-color: rgb(17 24 39 / var(--tw-bg-opacity)); } .popup-metadata { color: white; background: #0000; display: inline-block; white-space: pre-wrap; font-size: 0.75em; } /* fullpage image viewer */ @@ -80,7 +79,7 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt /* extra networks */ .extra-networks > div { margin: 0; border-bottom: none !important; } -.extra-networks .second-line { display: flex; width: -moz-available; width: -webkit-fill-available; gap: 0.3em; box-shadow: var(--input-shadow); } +.extra-networks .second-line { display: flex; width: -moz-available; width: -webkit-fill-available; gap: 0.3em; box-shadow: var(--input-shadow); margin-bottom: 2px; } .extra-networks .search { flex: 1; } .extra-networks .description { flex: 3; } .extra-networks .tab-nav > button { margin-right: 0; height: 24px; padding: 2px 4px 2px 4px; } @@ -89,7 +88,7 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt .extra-networks .custom-button { width: 120px; width: 100%; background: none; justify-content: left; text-align: left; padding: 2px 8px 2px 16px; text-indent: -8px; box-shadow: none; line-break: auto; } .extra-networks .custom-button:hover { background: var(--button-primary-background-fill) } .extra-networks-tab { padding: 0 !important; } -.extra-network-subdirs { background: var(--input-background-fill); overflow-x: hidden; overflow-y: auto; min-width: max(15%, 120px); padding-top: 0.5em; margin-top: -4px !important; } +.extra-network-subdirs { background: var(--input-background-fill); overflow-x: hidden; overflow-y: auto; min-width: max(15%, 120px); padding-top: 0.5em; } .extra-networks-page { display: flex } .extra-network-cards { display: flex; flex-wrap: wrap; overflow-y: auto; overflow-x: hidden; align-content: flex-start; width: -moz-available; width: -webkit-fill-available; } .extra-network-cards .card { height: 
fit-content; margin: 0 0 0.5em 0.5em; position: relative; scroll-snap-align: start; scroll-margin-top: 0; } diff --git a/javascript/black-teal-reimagined.css b/javascript/black-teal-reimagined.css new file mode 100644 index 000000000..28176d247 --- /dev/null +++ b/javascript/black-teal-reimagined.css @@ -0,0 +1,1193 @@ +/* Generic HTML Tags */ +@font-face { + font-family: 'NotoSans'; + font-display: swap; + font-style: normal; + font-weight: 100; + src: local('NotoSans'), url('notosans-nerdfont-regular.ttf'); +} + +html { + scroll-behavior: smooth; +} + +:root, +.light, +.dark { + --font: 'NotoSans'; + --font-mono: 'ui-monospace', 'Consolas', monospace; + --font-size: 16px; + + /* Primary Colors */ + --primary-50: #7dffff; + --primary-100: #72e8e8; + --primary-200: #67d2d2; + --primary-300: #5dbcbc; + --primary-400: #52a7a7; + --primary-500: #489292; + --primary-600: #3e7d7d; + --primary-700: #356969; + --primary-800: #2b5656; + --primary-900: #224444; + --primary-950: #193232; + + /* Neutral Colors */ + --neutral-50: #f0f0f0; + --neutral-100: #e0e0e0; + --neutral-200: #d0d0d0; + --neutral-300: #b0b0b0; + --neutral-400: #909090; + --neutral-500: #707070; + --neutral-600: #606060; + --neutral-700: #404040; + --neutral-800: #303030; + --neutral-900: #202020; + --neutral-950: #101010; + + /* Highlight and Inactive Colors */ + --highlight-color: var(--primary-200); + --inactive-color: var(--primary-800); + + /* Text Colors */ + --body-text-color: var(--neutral-100); + --body-text-color-subdued: var(--neutral-300); + + /* Background Colors */ + --background-color: var(--neutral-950); + --background-fill-primary: var(--neutral-800); + --input-background-fill: var(--neutral-900); + + /* Padding and Borders */ + --input-padding: 4px; + --input-shadow: none; + --button-primary-text-color: var(--neutral-100); + --button-primary-background-fill: var(--primary-600); + --button-primary-background-fill-hover: var(--primary-800); + --button-secondary-text-color: var(--neutral-100); + --button-secondary-background-fill: var(--neutral-900); + --button-secondary-background-fill-hover: var(--neutral-600); + + /* Border Radius */ + --radius-xs: 2px; + --radius-sm: 4px; + --radius-md: 6px; + --radius-lg: 8px; + --radius-xl: 10px; + --radius-xxl: 15px; + --radius-xxxl: 20px; + + /* Shadows */ + --shadow-sm: 0 1px 2px rgba(0, 0, 0, 0.1); + --shadow-md: 0 2px 4px rgba(0, 0, 0, 0.1); + --shadow-lg: 0 4px 8px rgba(0, 0, 0, 0.1); + --shadow-xl: 0 8px 16px rgba(0, 0, 0, 0.1); + + /* Animation */ + --transition: all 0.3s ease; + + /* Scrollbar */ + --scrollbar-bg: var(--neutral-800); + --scrollbar-thumb: var(--highlight-color); +} + +html { + font-size: var(--font-size); + font-family: var(--font); +} + +body, +button, +input, +select, +textarea { + font-family: var(--font); + color: var(--body-text-color); + transition: var(--transition); +} + +button { + max-width: 400px; + white-space: nowrap; + padding: 8px 12px; + border: none; + border-radius: var(--radius-md); + background-color: var(--button-primary-background-fill); + color: var(--button-primary-text-color); + cursor: pointer; + box-shadow: var(--shadow-sm); + transition: transform 0.2s ease, background-color 0.3s ease; +} + +button:hover { + background-color: var(--button-primary-background-fill-hover); + transform: scale(1.05); +} + +/* Range Input Styles */ +.slider-container { + width: 100%; + /* Ensures the container takes full width */ + max-width: 100%; + /* Prevents overflow */ + padding: 0 10px; + /* Adds padding for aesthetic spacing */ + 
box-sizing: border-box; + /* Ensures padding doesn't affect width */ +} + +input[type='range'] { + display: block; + margin: 0; + padding: 0; + height: 1em; + background-color: transparent; + overflow: hidden; + cursor: pointer; + box-shadow: none; + -webkit-appearance: none; + opacity: 0.7; + appearance: none; + width: 100%; + /* Makes the slider responsive */ +} + +input[type='range'] { + opacity: 1; +} + +input[type='range']::-webkit-slider-thumb { + -webkit-appearance: none; + height: 1em; + width: 1em; + background-color: var(--highlight-color); + border-radius: var(--radius-xs); + box-shadow: var(--shadow-md); + cursor: pointer; + /* Ensures the thumb is clickable */ +} + +input[type='range']::-webkit-slider-runnable-track { + -webkit-appearance: none; + height: 6px; + background: var(--input-background-fill); + border-radius: var(--radius-md); +} + +input[type='range']::-moz-range-thumb { + height: 1em; + width: 1em; + background-color: var(--highlight-color); + border-radius: var(--radius-xs); + box-shadow: var(--shadow-md); + cursor: pointer; + /* Ensures the thumb is clickable */ +} + +input[type='range']::-moz-range-track { + height: 6px; + background: var(--input-background-fill); + border-radius: var(--radius-md); +} + +@media (max-width: 768px) { + .slider-container { + width: 100%; + /* Adjust width for smaller screens */ + } + + .networks-menu, + .styles-menu { + width: 100%; + /* Ensure menus are full width */ + margin: 0; + /* Reset margins for smaller screens */ + } +} + +/* Scrollbar Styles */ +:root { + scrollbar-color: var(--scrollbar-thumb) var(--scrollbar-bg); +} + +::-webkit-scrollbar { + width: 12px; + height: 12px; +} + +::-webkit-scrollbar-track { + background: var(--scrollbar-bg); + border-radius: var(--radius-lg); +} + +::-webkit-scrollbar-thumb { + background-color: var(--scrollbar-thumb); + border-radius: var(--radius-lg); + box-shadow: var(--shadow-sm); +} + +/* Tab Navigation Styles */ +.tab-nav { + display: flex; + /* Use flexbox for layout */ + justify-content: space-evenly; + /* Space out the tabs evenly */ + align-items: center; + /* Center items vertically */ + background: var(--background-color); + /* Background color */ + border-bottom: 1px dashed var(--highlight-color) !important; + /* Bottom border for separation */ + box-shadow: var(--shadow-md); + /* Shadow for depth */ + margin-bottom: 5px; + /* Add some space between the tab nav and the content */ + padding-bottom: 5px; + /* Add space between buttons and border */ +} + +/* Individual Tab Styles */ +.tab-nav>button { + background: var(--neutral-900); + /* No background for default state */ + color: var(--text-color); + /* Text color */ + border: 1px solid var(--highlight-color); + /* No border */ + border-radius: var(--radius-xxl); + /* Rounded corners */ + cursor: pointer; + /* Pointer cursor */ + transition: background 0.3s ease, color 0.3s ease; + /* Smooth transition */ + padding-top: 5px; + padding-bottom: 5px; + padding-right: 10px; + padding-left: 10px; + margin-bottom: 3px; +} + +/* Active Tab Style */ +.tab-nav>button.selected { + background: var(--primary-100); + /* Highlight active tab */ + color: var(--background-color); + /* Change text color for active tab */ +} + +/* Hover State for Tabs */ +.tab-nav>button:hover { + background: var(--highlight-color); + /* Background on hover */ + color: var(--background-color); + /* Change text color on hover */ +} + +/* Responsive Styles */ +@media (max-width: 768px) { + .tab-nav { + flex-direction: column; + /* Stack tabs vertically on 
smaller screens */ + align-items: stretch; + /* Stretch tabs to full width */ + } + + .tab-nav>button { + width: 100%; + /* Full width for buttons */ + text-align: left; + /* Align text to the left */ + } +} + +/* Quick Settings Panel Styles */ +#quicksettings { + background: var(--background-color); + /* Background color */ + box-shadow: var(--shadow-lg); + /* Shadow for depth */ + border-radius: var(--radius-lg); + /* Rounded corners */ + padding: 1em; + /* Padding for spacing */ + z-index: 200; + /* Ensure it stays on top */ +} + +/* Quick Settings Header */ +#quicksettings .header { + font-size: var(--text-lg); + /* Font size for header */ + font-weight: bold; + /* Bold text */ + margin-bottom: 0.5em; + /* Space below header */ +} + +/* Quick Settings Options */ +#quicksettings .option { + display: flex; + /* Flexbox for layout */ + justify-content: space-between; + /* Space between label and toggle */ + align-items: center; + /* Center items vertically */ + padding: 0.5em 0; + /* Padding for each option */ + border-bottom: 1px solid var(--neutral-600); + /* Separator line */ +} + +/* Option Label Styles */ +#quicksettings .option label { + color: var(--text-color); + /* Text color */ +} + +/* Toggle Switch Styles */ +#quicksettings .option input[type="checkbox"] { + cursor: pointer; + /* Pointer cursor */ +} + +/* Quick Settings Footer */ +#quicksettings .footer { + margin-top: 1em; + /* Space above footer */ + text-align: right; + /* Align text to the right */ +} + +/* Close Button Styles */ +#quicksettings .footer button { + background: var(--button-primary-background-fill); + /* Button background */ + color: var(--button-primary-text-color); + /* Button text color */ + border: none; + /* No border */ + border-radius: var(--radius-md); + /* Rounded corners */ + padding: 0.5em 1em; + /* Padding for button */ + cursor: pointer; + /* Pointer cursor */ + transition: 0.3s ease; + /* Smooth transition */ +} + +/* Close Button Hover State */ +#quicksettings .footer button:hover { + background: var(--highlight-color); + /* Change background on hover */ +} + +/* Responsive Styles */ +@media (max-width: 768px) { + #quicksettings { + right: 10px; + /* Adjust position for smaller screens */ + width: 90%; + /* Full width on smaller screens */ + } +} + +/* Form Styles */ +div.form, #txt2img_seed_row, #txt2img_subseed_row { + border-width: 0; + box-shadow: var(--shadow-md); + background: var(--background-fill-primary); + border-bottom: 3px solid var(--highlight-color); + padding: 3px; + border-radius: var(--radius-lg); + margin: 1px; +} + +/* Image preview styling*/ +#txt2img_gallery { + background: var(--background-fill-primary); + padding: 5px; + margin: 0px; +} + +@keyframes colorChange { + 0% { + background-color: var(--neutral-800); + } + 50% { + background-color: var(--neutral-700); + } + 100% { + background-color: var(--neutral-800); + } +} + +.livePreview { + animation: colorChange 3s ease-in-out infinite; /* Adjust the duration as needed */ + padding: 5px; +} + +/* Gradio Style Classes */ +fieldset .gr-block.gr-box, +label.block span { + padding: 0; + margin-top: -4px; +} + +.border-2 { + border-width: 0; +} + +.border-b-2 { + border-bottom-width: 2px; + border-color: var(--highlight-color) !important; + padding-bottom: 2px; + margin-bottom: 8px; +} + +.bg-white { + color: lightyellow; + background-color: var(--inactive-color); +} + +.gr-box { + border-radius: var(--radius-sm) !important; + background-color: var(--neutral-950) !important; + box-shadow: var(--shadow-md); + border-width: 0; 
+ padding: 4px; + margin: 12px 0; +} + +.gr-button { + font-weight: normal; + box-shadow: var(--shadow-sm); + font-size: 0.8rem; + min-width: 32px; + min-height: 32px; + padding: 3px; + margin: 3px; + transition: var(--transition); +} + +.gr-button:hover { + background-color: var(--highlight-color); +} + +.gr-check-radio { + background-color: var(--inactive-color); + border-width: 0; + border-radius: var(--radius-lg); + box-shadow: var(--shadow-sm); +} + +.gr-check-radio:checked { + background-color: var(--highlight-color); +} + +.gr-compact { + background-color: var(--background-color); +} + +.gr-form { + border-width: 0; +} + +.gr-input { + background-color: var(--neutral-800) !important; + padding: 4px; + margin: 4px; + border-radius: var(--radius-md); + transition: var(--transition); +} + +.gr-input:hover { + background-color: var(--neutral-700); +} + +.gr-input-label { + color: lightyellow; + border-width: 0; + background: transparent; + padding: 2px !important; +} + +.gr-panel { + background-color: var(--background-color); + border-radius: var(--radius-md); + box-shadow: var(--shadow-md); +} + +.eta-bar { + display: none !important; +} + +.gradio-slider { + max-width: 200px; +} + +.gradio-slider input[type="number"] { + background: var(--neutral-950); + margin-top: 2px; +} + +.gradio-image { + height: unset !important; +} + +svg.feather.feather-image, +.feather .feather-image { + display: none; +} + +.gap-2 { + padding-top: 8px; +} + +.gr-box>div>div>input.gr-text-input { + right: 0; + width: 4em; + padding: 0; + top: -12px; + border: none; + max-height: 20px; +} + +.output-html { + line-height: 1.2 rem; + overflow-x: hidden; +} + +.output-html>div { + margin-bottom: 8px; +} + +.overflow-hidden .flex .flex-col .relative col .gap-4 { + min-width: var(--left-column); + max-width: var(--left-column); +} + +.p-2 { + padding: 0; +} + +.px-4 { + padding-left: 1rem; + padding-right: 1rem; +} + +.py-6 { + padding-bottom: 0; +} + +.tabs { + background-color: var(--background-color); +} + +.block.token-counter span { + background-color: var(--input-background-fill) !important; + box-shadow: 2px 2px 2px #111; + border: none !important; + font-size: 0.7rem; +} + +.label-wrap { + margin: 8px 0px 4px 0px; +} + +.gradio-button.tool { + border: none; + background: none; + box-shadow: none; + filter: hue-rotate(340deg) saturate(0.5); +} + +#tab_extensions table td, +#tab_extensions table th, +#tab_config table td, +#tab_config table th { + border: none; +} + +#tab_extensions table tr:hover, +#tab_config table tr:hover { + background-color: var(--neutral-500) !important; +} + +#tab_extensions table, +#tab_config table { + width: 96vw; +} + +#tab_extensions table thead, +#tab_config table thead { + background-color: var(--neutral-700); +} + +#tab_extensions table, +#tab_config table { + background-color: var(--neutral-900); +} + +/* Automatic Style Classes */ +.progressDiv { + border-radius: var(--radius-sm) !important; + position: fixed; + top: 44px; + right: 26px; + max-width: 262px; + height: 48px; + z-index: 99; + box-shadow: var(--button-shadow); +} + +.progressDiv .progress { + border-radius: var(--radius-lg) !important; + background: var(--highlight-color); + line-height: 3rem; + height: 48px; +} + +.gallery-item { + box-shadow: none !important; +} + +.performance { + color: #888; +} + +.image-buttons { + justify-content: center; + gap: 0 !important; +} + +.image-buttons>button { + max-width: 160px; +} + +.tooltip { + background: var(--primary-300); + color: black; + border: none; + 
border-radius: var(--radius-lg); +} + +#system_row>button, +#settings_row>button, +#config_row>button { + max-width: 10em; +} + +/* Gradio Elements Overrides */ +#div.gradio-container { + overflow-x: hidden; +} + +#img2img_label_copy_to_img2img { + font-weight: normal; +} + +#txt2img_styles, +#img2img_styles, +#control_styles { + padding: 0; + margin-top: 2px; +} + +#txt2img_styles_refresh, +#img2img_styles_refresh, +#control_styles_refresh { + padding: 0; + margin-top: 1em; +} + +#img2img_settings { + min-width: calc(2 * var(--left-column)); + max-width: calc(2 * var(--left-column)); + background-color: var(--neutral-950); + padding-top: 16px; +} + +#interrogate, +#deepbooru { + margin: 0 0px 10px 0px; + max-width: 80px; + max-height: 80px; + font-weight: normal; + font-size: 0.95em; +} + +#quicksettings .gr-button-tool { + font-size: 1.6rem; + box-shadow: none; + margin-left: -20px; + margin-top: -2px; + height: 2.4em; +} + +#save-animation { + border-radius: var(--radius-sm) !important; + margin-bottom: 16px; + background-color: var(--neutral-950); +} + +#script_list { + padding: 4px; + margin-top: 16px; + margin-bottom: 8px; +} + +#settings>div.flex-wrap { + width: 15em; +} + +#txt2img_cfg_scale { + min-width: 200px; +} + +#txt2img_checkboxes, +#img2img_checkboxes, +#control_checkboxes { + background-color: transparent; + margin-bottom: 0.2em; +} + +#extras_upscale { + margin-top: 10px; +} + +#txt2img_progress_row>div { + min-width: var(--left-column); + max-width: var(--left-column); +} + +#txt2img_settings { + min-width: var(--left-column); + max-width: var(--left-column); + background-color: var(--neutral-950); +} + +#pnginfo_html2_info { + margin-top: -18px; + background-color: var(--input-background-fill); + padding: var(--input-padding); +} + +#txt2img_styles_row, +#img2img_styles_row, +#control_styles_row { + margin-top: -6px; +} + +.block>span { + margin-bottom: 0 !important; + margin-top: var(--spacing-lg); +} + +/* Extra Networks Container */ +#extra_networks_root { + z-index: 100; + background: var(--background-color); + box-shadow: var(--shadow-md); + border-radius: var(--radius-lg); + transform: translateX(100%); + animation: slideIn 0.5s forwards; + overflow: hidden; + /* Prevents overflow of content */ +} + +@keyframes slideIn { + to { + transform: translateX(0); + } +} + +/* Extra Networks Styles */ +.extra-networks { + border-left: 2px solid var(--highlight-color) !important; + padding-left: 4px; +} + +.extra-networks .tab-nav>button:hover { + background: var(--highlight-color); +} + +/* Network tab search and description important fix, dont remove */ +#txt2img_description, +#txt2img_extra_search, +#img2img_description, +#img2img_extra_search, +#control_description, +#control_extra_search { + margin-top: 50px; +} + +.extra-networks .buttons>button:hover { + background: var(--highlight-color); +} + +/* Network Cards Container */ +.extra-network-cards { + display: flex; + flex-wrap: wrap; + overflow-y: auto; + overflow-x: hidden; + align-content: flex-start; + padding-top: 20px; + justify-content: center; + width: 100%; + /* Ensures it takes full width */ +} + +/* Individual Card Styles */ +.extra-network-cards .card { + height: fit-content; + margin: 0 0 0.5em 0.5em; + position: relative; + scroll-snap-align: start; + scroll-margin-top: 0; + background: var(--neutral-800); + /* Background for cards */ + border-radius: var(--radius-md); + box-shadow: var(--shadow-md); + transition: var(--transition); +} + +/* Overlay Styles */ +.extra-network-cards .card .overlay { + 
z-index: 10; + width: 100%; + background: none; + border-radius: var(--radius-md); +} + +/* Overlay Name Styles */ +.extra-network-cards .card .overlay .name { + font-size: var(--text-lg); + font-weight: bold; + text-shadow: 1px 1px black; + color: white; + overflow-wrap: anywhere; + position: absolute; + bottom: 0; + padding: 0.2em; + z-index: 10; +} + +/* Preview Styles */ +.extra-network-cards .card .preview { + box-shadow: var(--button-shadow); + min-height: 30px; + border-radius: var(--radius-md); + z-index: 9999; +} + +/* Hover Effects */ +.extra-network-cards .card:hover { + transform: scale(1.3); + z-index: 9999; /* Use a high value to ensure it appears on top */ + transition: transform 0.3s ease, z-index 0s; /* Smooth transition */ +} + +.extra-network-cards .card:hover .overlay { + z-index: 10000; /* Ensure overlay is also on top */ +} + +.extra-network-cards .card:hover .preview { + box-shadow: none; + filter: grayscale(0%); +} + +/* Tags Styles */ +.extra-network-cards .card .overlay .tags { + display: none; + overflow-wrap: anywhere; + position: absolute; + top: 100%; + z-index: 20; + background: var(--body-background-fill); + overflow-x: hidden; + overflow-y: auto; + max-height: 333px; +} + +/* Individual Tag Styles */ +.extra-network-cards .card .overlay .tag { + padding: 2px; + margin: 2px; + background: rgba(70, 70, 70, 0.60); + font-size: var(--text-md); + cursor: pointer; + display: inline-block; +} + +/* Actions Styles */ +.extra-network-cards .card .actions>span { + padding: 4px; + font-size: 34px !important; +} + +.extra-network-cards .card .actions { + background: none; +} + +.extra-network-cards .card .actions .details { + bottom: 50px; + background-color: var(--neutral-800); +} + +.extra-network-cards .card .actions>span:hover { + color: var(--highlight-color); +} + +/* Version Styles */ +.extra-network-cards .card .version { + position: absolute; + top: 0; + left: 0; + padding: 2px; + font-weight: bolder; + text-shadow: 1px 1px black; + text-transform: uppercase; + background: gray; + opacity: 75%; + margin: 4px; + line-height: 0.9rem; +} + +/* Hover Actions */ +.extra-network-cards .card:hover .actions { + display: block; +} + +.extra-network-cards .card:hover .overlay .tags { + display: block; +} + +/* No Preview Card Styles */ +.extra-network-cards .card:has(>img[src*="card-no-preview.png"])::before { + content: ''; + position: absolute; + width: 100%; + height: 100%; + mix-blend-mode: multiply; + background-color: var(--data-color); +} + +/* Card List Styles */ +.extra-network-cards .card-list { + display: flex; + margin: 0.3em; + padding: 0.3em; + background: var(--input-background-fill); + cursor: pointer; + border-radius: var(--button-large-radius); +} + +.extra-network-cards .card-list .tag { + color: var(--primary-500); + margin-left: 0.8em; +} + +/* Correction color picker styling */ +#txt2img_hdr_color_picker label input { + width: 100%; + height: 100%; +} + +/* loader */ +.splash { + position: fixed; + top: 0; + left: 0; + width: 100vw; + height: 100vh; + z-index: 1000; + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + background-color: rgba(0, 0, 0, 0.8); +} + +.motd { + margin-top: 1em; + color: var(--body-text-color-subdued); + font-family: monospace; + font-variant: all-petite-caps; + font-size: 1.2em; +} + +.splash-img { + margin: 0; + width: 512px; + height: 512px; + background-repeat: no-repeat; + animation: color 8s infinite alternate, move 3s infinite alternate; +} + +.loading { + color: white; + 
position: border-box; + top: 85%; + font-size: 1.5em; +} + +.loader { + width: 100px; + height: 100px; + border: var(--spacing-md) solid transparent; + border-radius: 50%; + border-top: var(--spacing-md) solid var(--primary-600); + animation: spin 2s linear infinite, pulse 1.5s ease-in-out infinite; + position: border-box; +} + +.loader::before, +.loader::after { + content: ""; + position: absolute; + top: 6px; + bottom: 6px; + left: 6px; + right: 6px; + border-radius: 50%; + border: var(--spacing-md) solid transparent; +} + +.loader::before { + border-top-color: var(--primary-900); + animation: spin 3s linear infinite; +} + +.loader::after { + border-top-color: var(--primary-300); + animation: spin 1.5s linear infinite; +} + +@keyframes move { + 0% { + transform: translateY(0); + } + 50% { + transform: translateY(-10px); + } + 100% { + transform: translateY(0); + } +} + +@keyframes spin { + from { + transform: rotate(0deg); + } + to { + transform: rotate(360deg); + } +} + +@keyframes pulse { + 0%, 100% { + transform: scale(1); + } + 50% { + transform: scale(1.1); + } +} + +@keyframes color { + 0% { + filter: hue-rotate(0deg); + } + 100% { + filter: hue-rotate(360deg); + } +} + +/* Token counters styling */ +#txt2img_token_counter, #txt2img_negative_token_counter { + display: flex; + flex-direction: row; + padding-top: 1px; + opacity: 0.6; + z-index: 99; +} + +#txt2img_prompt_container { + margin: 5px; + padding: 0px; +} + +#text2img_prompt label, #text2img_neg_prompt label { + margin: 0px; +} + +/* Based on Gradio Built-in Dark Theme */ +:root, +.light, +.dark { + --body-background-fill: var(--background-color); + --color-accent-soft: var(--neutral-700); + --background-fill-secondary: none; + --border-color-accent: var(--background-color); + --border-color-primary: var(--background-color); + --link-text-color-active: var(--primary-500); + --link-text-color: var(--secondary-500); + --link-text-color-hover: var(--secondary-400); + --link-text-color-visited: var(--secondary-600); + --shadow-spread: 1px; + --block-background-fill: none; + --block-border-color: var(--border-color-primary); + --block_border_width: none; + --block-info-text-color: var(--body-text-color-subdued); + --block-label-background-fill: var(--background-fill-secondary); + --block-label-border-color: var(--border-color-primary); + --block_label_border_width: none; + --block-label-text-color: var(--neutral-200); + --block-shadow: none; + --block-title-background-fill: none; + --block-title-border-color: none; + --block-title-border-width: 0px; + --block-title-padding: 0; + --block-title-radius: none; + --block-title-text-size: var(--text-md); + --block-title-text-weight: 400; + --container-radius: var(--radius-lg); + --form-gap-width: 1px; + --layout-gap: var(--spacing-xxl); + --panel-border-width: 0; + --section-header-text-size: var(--text-md); + --section-header-text-weight: 400; + --checkbox-border-radius: var(--radius-sm); + --checkbox-label-gap: 2px; + --checkbox-label-padding: var(--spacing-md); + --checkbox-label-shadow: var(--shadow-drop); + --checkbox-label-text-size: var(--text-md); + --checkbox-label-text-weight: 400; + --checkbox-check: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e"); + --radio-circle: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle 
cx='8' cy='8' r='3'/%3e%3c/svg%3e"); + --checkbox-shadow: var(--input-shadow); + --error-border-width: 1px; + --input-border-width: 0; + --input-radius: var(--radius-lg); + --input-text-size: var(--text-md); + --input-text-weight: 400; + --prose-text-size: var(--text-md); + --prose-text-weight: 400; + --prose-header-text-weight: 400; + --slider-color: var(--neutral-900); + --table-radius: var(--radius-lg); + --button-large-padding: 2px 6px; + --button-large-radius: var(--radius-lg); + --button-large-text-size: var(--text-lg); + --button-large-text-weight: 400; + --button-shadow: none; + --button-shadow-active: none; + --button-shadow-hover: none; + --button-small-padding: var(--spacing-sm) calc(2 * var(--spacing-sm)); + --button-small-radius: var(--radius-lg); + --button-small-text-size: var(--text-md); + --button-small-text-weight: 400; + --button-transition: none; + --size-9: 64px; + --size-14: 64px; +} \ No newline at end of file diff --git a/javascript/black-teal.css b/javascript/black-teal.css index c6f266c54..2ebf32e96 100644 --- a/javascript/black-teal.css +++ b/javascript/black-teal.css @@ -134,7 +134,7 @@ svg.feather.feather-image, .feather .feather-image { display: none } .gallery-item { box-shadow: none !important; } .performance { color: #888; } .extra-networks { border-left: 2px solid var(--highlight-color) !important; padding-left: 4px; } -.image-buttons { gap: 10px !important; justify-content: center; } +.image-buttons { justify-content: center; gap: 0 !important; } .image-buttons > button { max-width: 160px; } .tooltip { background: var(--primary-300); color: black; border: none; border-radius: var(--radius-lg) } #system_row > button, #settings_row > button, #config_row > button { max-width: 10em; } diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 77fe125f3..1d1bcfb24 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -3,19 +3,6 @@ let sortVal = -1; // helpers -const requestGet = (url, data, handler) => { - const xhr = new XMLHttpRequest(); - const args = Object.keys(data).map((k) => `${encodeURIComponent(k)}=${encodeURIComponent(data[k])}`).join('&'); - xhr.open('GET', `${url}?${args}`, true); - xhr.onreadystatechange = () => { - if (xhr.readyState === 4) { - if (xhr.status === 200) handler(JSON.parse(xhr.responseText)); - else console.error(`Request: url=${url} status=${xhr.status} err`); - } - }; - xhr.send(JSON.stringify(data)); -}; - const getENActiveTab = () => { let tabName = ''; if (gradioApp().getElementById('tab_txt2img').style.display === 'block') tabName = 'txt2img'; @@ -98,7 +85,7 @@ function readCardTags(el, tags) { } function readCardDescription(page, item) { - requestGet('/sd_extra_networks/description', { page, item }, (data) => { + xhrGet('/sd_extra_networks/description', { page, item }, (data) => { const tabname = getENActiveTab(); const description = gradioApp().querySelector(`#${tabname}_description > label > textarea`); description.value = data?.description?.trim() || ''; @@ -447,6 +434,22 @@ function setupExtraNetworksForTab(tabname) { }; } + // auto-resize networks sidebar + const resizeObserver = new ResizeObserver((entries) => { + for (const entry of entries) { + for (const el of Array.from(gradioApp().getElementById(`${tabname}_extra_tabs`).querySelectorAll('.extra-networks-page'))) { + const h = Math.trunc(entry.contentRect.height); + if (h <= 0) return; + if (window.opts.extra_networks_card_cover === 'sidebar' && window.opts.theme_type === 'Standard') el.style.height = `max(55vh, ${h - 
90}px)`; + // log(`${tabname} height: ${entry.target.id}=${h} ${el.id}=${el.clientHeight}`); + } + } + }); + const settingsEl = gradioApp().getElementById(`${tabname}_settings`); + const interfaceEl = gradioApp().getElementById(`${tabname}_interface`); + if (settingsEl) resizeObserver.observe(settingsEl); + if (interfaceEl) resizeObserver.observe(interfaceEl); + // en style if (!en) return; let lastView; diff --git a/javascript/gallery.js b/javascript/gallery.js index 1f3afd148..05e594e4c 100644 --- a/javascript/gallery.js +++ b/javascript/gallery.js @@ -94,14 +94,14 @@ async function delayFetchThumb(fn) { outstanding++; const res = await fetch(`/sdapi/v1/browser/thumb?file=${encodeURI(fn)}`, { priority: 'low' }); if (!res.ok) { - console.error(res.statusText); + error(`fetchThumb: ${res.statusText}`); outstanding--; return undefined; } const json = await res.json(); outstanding--; if (!res || !json || json.error || Object.keys(json).length === 0) { - if (json.error) console.error(json.error); + if (json.error) error(`fetchThumb: ${json.error}`); return undefined; } return json; diff --git a/javascript/imageMaskFix.js b/javascript/imageMaskFix.js deleted file mode 100644 index fd37caf90..000000000 --- a/javascript/imageMaskFix.js +++ /dev/null @@ -1,38 +0,0 @@ -/** - * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668 - * @see https://github.com/gradio-app/gradio/issues/1721 - */ -function imageMaskResize() { - const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas'); - if (!canvases.length) { - window.removeEventListener('resize', imageMaskResize); - return; - } - const wrapper = canvases[0].closest('.touch-none'); - const previewImage = wrapper.previousElementSibling; - if (!previewImage.complete) { - previewImage.addEventListener('load', imageMaskResize); - return; - } - const w = previewImage.width; - const h = previewImage.height; - const nw = previewImage.naturalWidth; - const nh = previewImage.naturalHeight; - const portrait = nh > nw; - const wW = Math.min(w, portrait ? h / nh * nw : w / nw * nw); - const wH = Math.min(h, portrait ? 
h / nh * nh : w / nw * nh); - wrapper.style.width = `${wW}px`; - wrapper.style.height = `${wH}px`; - wrapper.style.left = '0px'; - wrapper.style.top = '0px'; - canvases.forEach((c) => { - c.style.width = ''; - c.style.height = ''; - c.style.maxWidth = '100%'; - c.style.maxHeight = '100%'; - c.style.objectFit = 'contain'; - }); -} - -onAfterUiUpdate(imageMaskResize); -window.addEventListener('resize', imageMaskResize); diff --git a/javascript/light-teal.css b/javascript/light-teal.css index 28bf03e6f..174622e52 100644 --- a/javascript/light-teal.css +++ b/javascript/light-teal.css @@ -20,9 +20,9 @@ --body-text-color: var(--neutral-800); --body-text-color-subdued: var(--neutral-600); --background-color: #FFFFFF; - --background-fill-primary: var(--neutral-400); + --background-fill-primary: var(--neutral-300); --input-padding: 4px; - --input-background-fill: var(--neutral-300); + --input-background-fill: var(--neutral-200); --input-shadow: 2px 2px 2px 2px var(--neutral-500); --button-secondary-text-color: black; --button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-200), var(--neutral-500)); @@ -291,8 +291,8 @@ svg.feather.feather-image, .feather .feather-image { display: none } --slider-color: ; --stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600)); --table-border-color: var(--neutral-700); - --table-even-background-fill: #222222; - --table-odd-background-fill: #333333; + --table-even-background-fill: #FFFFFF; + --table-odd-background-fill: #CCCCCC; --table-radius: var(--radius-lg); --table-row-focus: var(--color-accent-soft); } diff --git a/javascript/loader.js b/javascript/loader.js index f3c7fe60f..8cd4811bf 100644 --- a/javascript/loader.js +++ b/javascript/loader.js @@ -20,7 +20,7 @@ async function preloadImages() { try { await Promise.all(imagePromises); } catch (error) { - console.error('Error preloading images:', error); + error(`preloadImages: ${error}`); } } @@ -43,14 +43,16 @@ async function createSplash() { const motdEl = document.getElementById('motd'); if (motdEl) motdEl.innerHTML = text.replace(/["]+/g, ''); }) - .catch((err) => console.error('getMOTD:', err)); + .catch((err) => error(`getMOTD: ${err}`)); } async function removeSplash() { const splash = document.getElementById('splash'); if (splash) splash.remove(); log('removeSplash'); - log('startupTime', Math.round(performance.now() - appStartTime) / 1000); + const t = Math.round(performance.now() - appStartTime) / 1000; + log('startupTime', t); + xhrPost('/sdapi/v1/log', { message: `ready time=${t}` }); } window.onload = createSplash; diff --git a/javascript/logMonitor.js b/javascript/logMonitor.js index e4fe99a7f..9b915e6da 100644 --- a/javascript/logMonitor.js +++ b/javascript/logMonitor.js @@ -2,6 +2,7 @@ let logMonitorEl = null; let logMonitorStatus = true; let logWarnings = 0; let logErrors = 0; +let logConnected = false; function dateToStr(ts) { const dt = new Date(1000 * ts); @@ -29,8 +30,7 @@ async function logMonitor() { row.innerHTML = `${dateToStr(l.created)}${level}${l.facility}${module}${l.msg}`; logMonitorEl.appendChild(row); } catch (e) { - // console.log('logMonitor', e); - console.error('logMonitor line', line); + error(`logMonitor: ${line}`); } }; @@ -46,6 +46,7 @@ async function logMonitor() { if (logMonitorStatus) setTimeout(logMonitor, opts.logmonitor_refresh_period); else setTimeout(logMonitor, 10 * 1000); // on failure try to reconnect every 10sec + if (!opts.logmonitor_show) return; logMonitorStatus = false; if (!logMonitorEl) { @@ 
-64,14 +65,20 @@ async function logMonitor() { const lines = await res.json(); if (logMonitorEl && lines?.length > 0) logMonitorEl.parentElement.parentElement.style.display = opts.logmonitor_show ? 'block' : 'none'; for (const line of lines) addLogLine(line); + if (!logConnected) { + logConnected = true; + xhrPost('/sdapi/v1/log', { debug: 'connected' }); + } } else { - addLogLine(`{ "created": ${Date.now()}, "level":"ERROR", "module":"logMonitor", "facility":"ui", "msg":"Failed to fetch log: ${res?.status} ${res?.statusText}" }`); + logConnected = false; logErrors++; + addLogLine(`{ "created": ${Date.now()}, "level":"ERROR", "module":"logMonitor", "facility":"ui", "msg":"Failed to fetch log: ${res?.status} ${res?.statusText}" }`); } cleanupLog(atBottom); } catch (err) { - addLogLine(`{ "created": ${Date.now()}, "level":"ERROR", "module":"logMonitor", "facility":"ui", "msg":"Failed to fetch log: server unreachable" }`); + logConnected = false; logErrors++; + addLogLine(`{ "created": ${Date.now()}, "level":"ERROR", "module":"logMonitor", "facility":"ui", "msg":"Failed to fetch log: server unreachable" }`); cleanupLog(atBottom); } } diff --git a/javascript/logger.js b/javascript/logger.js new file mode 100644 index 000000000..8fa812b86 --- /dev/null +++ b/javascript/logger.js @@ -0,0 +1,68 @@ +const timeout = 10000; + +const log = async (...msg) => { + const dt = new Date(); + const ts = `${dt.getHours().toString().padStart(2, '0')}:${dt.getMinutes().toString().padStart(2, '0')}:${dt.getSeconds().toString().padStart(2, '0')}.${dt.getMilliseconds().toString().padStart(3, '0')}`; + if (window.logger) window.logger.innerHTML += window.logPrettyPrint(...msg); + console.log(ts, ...msg); // eslint-disable-line no-console +}; + +const debug = async (...msg) => { + const dt = new Date(); + const ts = `${dt.getHours().toString().padStart(2, '0')}:${dt.getMinutes().toString().padStart(2, '0')}:${dt.getSeconds().toString().padStart(2, '0')}.${dt.getMilliseconds().toString().padStart(3, '0')}`; + if (window.logger) window.logger.innerHTML += window.logPrettyPrint(...msg); + console.debug(ts, ...msg); // eslint-disable-line no-console +}; + +const error = async (...msg) => { + const dt = new Date(); + const ts = `${dt.getHours().toString().padStart(2, '0')}:${dt.getMinutes().toString().padStart(2, '0')}:${dt.getSeconds().toString().padStart(2, '0')}.${dt.getMilliseconds().toString().padStart(3, '0')}`; + if (window.logger) window.logger.innerHTML += window.logPrettyPrint(...msg); + console.error(ts, ...msg); // eslint-disable-line no-console + // const txt = msg.join(' '); + // if (!txt.includes('asctime') && !txt.includes('xhr.')) xhrPost('/sdapi/v1/log', { error: txt }); // eslint-disable-line no-use-before-define +}; + +const xhrInternal = (xhrObj, data, handler = undefined, errorHandler = undefined, ignore = false, serverTimeout = timeout) => { + const err = (msg) => { + if (!ignore) { + error(`${msg}: state=${xhrObj.readyState} status=${xhrObj.status} response=${xhrObj.responseText}`); + if (errorHandler) errorHandler(xhrObj); + } + }; + + xhrObj.setRequestHeader('Content-Type', 'application/json'); + xhrObj.timeout = timeout; + xhrObj.ontimeout = () => err('xhr.ontimeout'); + xhrObj.onerror = () => err('xhr.onerror'); + xhrObj.onabort = () => err('xhr.onabort'); + xhrObj.onreadystatechange = () => { + if (xhrObj.readyState === 4) { + if (xhrObj.status === 200) { + try { + const json = JSON.parse(xhrObj.responseText); + if (handler) handler(json); + } catch (e) { + error(`xhr.onreadystatechange: 
${e}`); + } + } else { + err(`xhr.onreadystatechange: state=${xhrObj.readyState} status=${xhrObj.status} response=${xhrObj.responseText}`); + } + } + }; + const req = JSON.stringify(data); + xhrObj.send(req); +}; + +const xhrGet = (url, data, handler = undefined, errorHandler = undefined, ignore = false, serverTimeout = timeout) => { + const xhr = new XMLHttpRequest(); + const args = Object.keys(data).map((k) => `${encodeURIComponent(k)}=${encodeURIComponent(data[k])}`).join('&'); + xhr.open('GET', `${url}?${args}`, true); + xhrInternal(xhr, data, handler, errorHandler, ignore, serverTimeout); +}; + +function xhrPost(url, data, handler = undefined, errorHandler = undefined, ignore = false, serverTimeout = timeout) { + const xhr = new XMLHttpRequest(); + xhr.open('POST', url, true); + xhrInternal(xhr, data, handler, errorHandler, ignore, serverTimeout); +} diff --git a/javascript/notification.js b/javascript/notification.js index 33e8d1c55..c702c90e7 100644 --- a/javascript/notification.js +++ b/javascript/notification.js @@ -4,28 +4,32 @@ let lastHeadImg = null; let notificationButton = null; async function sendNotification() { - if (!notificationButton) { - notificationButton = gradioApp().getElementById('request_notifications'); - if (notificationButton) notificationButton.addEventListener('click', (evt) => Notification.requestPermission(), true); + try { + if (!notificationButton) { + notificationButton = gradioApp().getElementById('request_notifications'); + if (notificationButton) notificationButton.addEventListener('click', (evt) => Notification.requestPermission(), true); + } + if (document.hasFocus()) return; // window is in focus so don't send notifications + let galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] .thumbnail-item > img'); + if (!galleryPreviews || galleryPreviews.length === 0) galleryPreviews = gradioApp().querySelectorAll('.thumbnail-item > img'); + if (!galleryPreviews || galleryPreviews.length === 0) return; + const headImg = galleryPreviews[0]?.src; + if (!headImg || headImg === lastHeadImg || headImg.includes('logo-bg-')) return; + const audioNotification = gradioApp().querySelector('#audio_notification audio'); + if (audioNotification) audioNotification.play(); + lastHeadImg = headImg; + const imgs = new Set(Array.from(galleryPreviews).map((img) => img.src)); // Multiple copies of the images are in the DOM when one is selected + const notification = new Notification('SD.Next', { + body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 
's' : ''}`, + icon: headImg, + image: headImg, + }); + notification.onclick = () => { + parent.focus(); + this.close(); + }; + log('sendNotifications'); + } catch (e) { + error(`sendNotification: ${e}`); } - if (document.hasFocus()) return; // window is in focus so don't send notifications - let galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] .thumbnail-item > img'); - if (!galleryPreviews || galleryPreviews.length === 0) galleryPreviews = gradioApp().querySelectorAll('.thumbnail-item > img'); - if (!galleryPreviews || galleryPreviews.length === 0) return; - const headImg = galleryPreviews[0]?.src; - if (!headImg || headImg === lastHeadImg || headImg.includes('logo-bg-')) return; - const audioNotification = gradioApp().querySelector('#audio_notification audio'); - if (audioNotification) audioNotification.play(); - lastHeadImg = headImg; - const imgs = new Set(Array.from(galleryPreviews).map((img) => img.src)); // Multiple copies of the images are in the DOM when one is selected - const notification = new Notification('SD.Next', { - body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 's' : ''}`, - icon: headImg, - image: headImg, - }); - notification.onclick = () => { - parent.focus(); - this.close(); - }; - log('sendNotifications'); } diff --git a/javascript/nvml.js b/javascript/nvml.js index cf0187367..0a82cba1b 100644 --- a/javascript/nvml.js +++ b/javascript/nvml.js @@ -1,6 +1,32 @@ let nvmlInterval = null; // eslint-disable-line prefer-const let nvmlEl = null; let nvmlTable = null; +const chartData = { mem: [], load: [] }; + +async function updateNVMLChart(mem, load) { + const maxLen = 120; + const colorRangeMap = $.range_map({ // eslint-disable-line no-undef + '0:5': '#fffafa', + '6:10': '#fff7ed', + '11:20': '#fed7aa', + '21:30': '#fdba74', + '31:40': '#fb923c', + '41:50': '#f97316', + '51:60': '#ea580c', + '61:70': '#c2410c', + '71:80': '#9a3412', + '81:90': '#7c2d12', + '91:100': '#6c2e12', + }); + const sparklineConfigLOAD = { type: 'bar', height: '100px', barWidth: '2px', barSpacing: '1px', chartRangeMin: 0, chartRangeMax: 100, barColor: '#89007D' }; + const sparklineConfigMEM = { type: 'bar', height: '100px', barWidth: '2px', barSpacing: '1px', chartRangeMin: 0, chartRangeMax: 100, colorMap: colorRangeMap, composite: true }; + if (chartData.load.length > maxLen) chartData.load.shift(); + chartData.load.push(load); + if (chartData.mem.length > maxLen) chartData.mem.shift(); + chartData.mem.push(mem); + $('#nvmlChart').sparkline(chartData.load, sparklineConfigLOAD); // eslint-disable-line no-undef + $('#nvmlChart').sparkline(chartData.mem, sparklineConfigMEM); // eslint-disable-line no-undef +} async function updateNVML() { try { @@ -35,6 +61,7 @@ async function updateNVML() { State${gpu.state} `; nvmlTbody.innerHTML = rows; + updateNVMLChart(gpu.load.memory, gpu.load.gpu); } nvmlEl.style.display = 'block'; } catch (e) { @@ -56,7 +83,10 @@ async function initNVML() { `; + const nvmlChart = document.createElement('div'); + nvmlChart.id = 'nvmlChart'; nvmlEl.appendChild(nvmlTable); + nvmlEl.appendChild(nvmlChart); gradioApp().appendChild(nvmlEl); log('initNVML'); } diff --git a/javascript/progressBar.js b/javascript/progressBar.js index a9ecb31e9..dfd895f8e 100644 --- a/javascript/progressBar.js +++ b/javascript/progressBar.js @@ -1,28 +1,5 @@ let lastState = {}; -function request(url, data, handler, errorHandler) { - const xhr = new XMLHttpRequest(); - xhr.open('POST', url, true); - 
xhr.setRequestHeader('Content-Type', 'application/json'); - xhr.onreadystatechange = () => { - if (xhr.readyState === 4) { - if (xhr.status === 200) { - try { - const js = JSON.parse(xhr.responseText); - handler(js); - } catch (error) { - console.error(error); - errorHandler(); - } - } else { - errorHandler(); - } - } - }; - const js = JSON.stringify(data); - xhr.send(js); -} - function pad2(x) { return x < 10 ? `0${x}` : x; } @@ -35,8 +12,10 @@ function formatTime(secs) { function checkPaused(state) { lastState.paused = state ? !state : !lastState.paused; - document.getElementById('txt2img_pause').innerText = lastState.paused ? 'Resume' : 'Pause'; - document.getElementById('img2img_pause').innerText = lastState.paused ? 'Resume' : 'Pause'; + const t_el = document.getElementById('txt2img_pause'); + const i_el = document.getElementById('img2img_pause'); + if (t_el) t_el.innerText = lastState.paused ? 'Resume' : 'Pause'; + if (i_el) i_el.innerText = lastState.paused ? 'Resume' : 'Pause'; } function setProgress(res) { @@ -89,28 +68,42 @@ function requestProgress(id_task, progressEl, galleryEl, atEnd = null, onProgres let img; const initLivePreview = () => { - img = new Image(); - if (parentGallery) { - livePreview = document.createElement('div'); - livePreview.className = 'livePreview'; - parentGallery.insertBefore(livePreview, galleryEl); - const rect = galleryEl.getBoundingClientRect(); - if (rect.width) { - livePreview.style.width = `${rect.width}px`; - livePreview.style.height = `${rect.height}px`; - } - img.onload = () => { - livePreview.appendChild(img); - if (livePreview.childElementCount > 2) livePreview.removeChild(livePreview.firstElementChild); - }; + if (!parentGallery) return; + const footers = Array.from(gradioApp().querySelectorAll('.gallery_footer')); + for (const footer of footers) { + if (footer.id !== 'gallery_footer') footer.style.display = 'none'; // hide all other footers + } + const galleries = Array.from(gradioApp().querySelectorAll('.gallery_main')); + for (const gallery of galleries) { + if (gallery.id !== 'gallery_gallery') gallery.style.display = 'none'; // hide all other galleries } + + livePreview = document.createElement('div'); + livePreview.className = 'livePreview'; + parentGallery.insertBefore(livePreview, galleryEl); + img = new Image(); + img.id = 'livePreviewImage'; + livePreview.appendChild(img); + img.onload = () => { + img.style.width = `min(100%, max(${img.naturalWidth}px, 512px))`; + parentGallery.style.minHeight = `${img.height}px`; + }; }; const done = () => { debug('taskEnd:', id_task); localStorage.removeItem('task'); setProgress(); - if (parentGallery && livePreview) parentGallery.removeChild(livePreview); + const footers = Array.from(gradioApp().querySelectorAll('.gallery_footer')); + for (const footer of footers) footer.style.display = 'flex'; // restore all footers + const galleries = Array.from(gradioApp().querySelectorAll('.gallery_main')); + for (const gallery of galleries) gallery.style.display = 'flex'; // restore all galleries + try { + if (parentGallery && livePreview) { + parentGallery.removeChild(livePreview); + parentGallery.style.minHeight = 'unset'; + } + } catch { /* ignore */ } checkPaused(true); sendNotification(); if (atEnd) atEnd(); @@ -118,20 +111,32 @@ function requestProgress(id_task, progressEl, galleryEl, atEnd = null, onProgres const start = (id_task, id_live_preview) => { // eslint-disable-line no-shadow if (!opts.live_previews_enable || opts.live_preview_refresh_period === 0 || opts.show_progress_every_n_steps === 0) return; - 
request('./internal/progress', { id_task, id_live_preview }, (res) => { + + const onProgressHandler = (res) => { + // debug('onProgress', res); lastState = res; const elapsedFromStart = (new Date() - dateStart) / 1000; hasStarted |= res.active; if (res.completed || (!res.active && (hasStarted || once)) || (elapsedFromStart > 30 && !res.queued && res.progress === prevProgress)) { + debug('onProgressEnd', res); done(); return; } setProgress(res); if (res.live_preview && !livePreview) initLivePreview(); - if (res.live_preview && galleryEl) img.src = res.live_preview; + if (res.live_preview && galleryEl) { + if (img.src !== res.live_preview) img.src = res.live_preview; + } if (onProgress) onProgress(res); setTimeout(() => start(id_task, id_live_preview), opts.live_preview_refresh_period || 500); - }, done); + }; + + const onProgressErrorHandler = (err) => { + error(`onProgressError: ${err}`); + done(); + }; + + xhrPost('./internal/progress', { id_task, id_live_preview }, onProgressHandler, onProgressErrorHandler, false, 5000); }; start(id_task, 0); } diff --git a/javascript/script.js b/javascript/script.js index 104567dd7..f943f4626 100644 --- a/javascript/script.js +++ b/javascript/script.js @@ -1,17 +1,3 @@ -const log = (...msg) => { - const dt = new Date(); - const ts = `${dt.getHours().toString().padStart(2, '0')}:${dt.getMinutes().toString().padStart(2, '0')}:${dt.getSeconds().toString().padStart(2, '0')}.${dt.getMilliseconds().toString().padStart(3, '0')}`; - if (window.logger) window.logger.innerHTML += window.logPrettyPrint(...msg); - console.log(ts, ...msg); // eslint-disable-line no-console -}; - -const debug = (...msg) => { - const dt = new Date(); - const ts = `${dt.getHours().toString().padStart(2, '0')}:${dt.getMinutes().toString().padStart(2, '0')}:${dt.getSeconds().toString().padStart(2, '0')}.${dt.getMilliseconds().toString().padStart(3, '0')}`; - if (window.logger) window.logger.innerHTML += window.logPrettyPrint(...msg); - console.debug(ts, ...msg); // eslint-disable-line no-console -}; - async function sleep(ms) { return new Promise((resolve) => setTimeout(resolve, ms)); // eslint-disable-line no-promise-executor-return } @@ -82,7 +68,7 @@ function executeCallbacks(queue, arg) { try { callback(arg); } catch (e) { - console.error('error running callback', callback, ':', e); + error(`executeCallbacks: ${callback} ${e}`); } } } @@ -139,11 +125,12 @@ document.addEventListener('keydown', (e) => { let elem; if (e.key === 'Escape') elem = getUICurrentTabContent().querySelector('button[id$=_interrupt]'); if (e.key === 'Enter' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id$=_generate]'); - if (e.key === 'Backspace' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id$=_reprocess]'); + if (e.key === 'i' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id$=_reprocess]'); if (e.key === ' ' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id$=_extra_networks_btn]'); + if (e.key === 'n' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id$=_extra_networks_btn]'); if (e.key === 's' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id^=save_]'); if (e.key === 'Insert' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id^=save_]'); - if (e.key === 'Delete' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id^=delete_]'); + if (e.key === 'd' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id^=delete_]'); // if (e.key === 
'm' && e.ctrlKey) elem = gradioApp().getElementById('setting_sd_model_checkpoint'); if (elem) { e.preventDefault(); diff --git a/javascript/sdnext.css b/javascript/sdnext.css index 08fae2eb8..6e33551ef 100644 --- a/javascript/sdnext.css +++ b/javascript/sdnext.css @@ -14,9 +14,9 @@ table { overflow-x: auto !important; overflow-y: auto !important; } td { border-bottom: none !important; padding: 0 0.5em !important; } tr { border-bottom: none !important; padding: 0 0.5em !important; } td > div > span { overflow-y: auto; max-height: 3em; overflow-x: hidden; } -textarea { overflow-y: auto !important; } +textarea { overflow-y: auto !important; border-radius: 4px !important; } span { font-size: var(--text-md) !important; } -button { font-size: var(--text-lg) !important; } +button { font-size: var(--text-lg) !important; min-width: unset !important; } input[type='color'] { width: 64px; height: 32px; } input::-webkit-outer-spin-button, input::-webkit-inner-spin-button { margin-left: 4px; } @@ -30,6 +30,19 @@ input::-webkit-outer-spin-button, input::-webkit-inner-spin-button { margin-left .hidden { display: none; } .tabitem { padding: 0 !important; } +/* gradio image/canvas elements */ +.image-container { overflow: auto; } +/* +.gradio-image { min-height: fit-content; } +.gradio-image img { object-fit: contain; } +*/ +/* +.gradio-image { min-height: 200px !important; } +.image-container { height: unset !important; } +.control-image { height: unset !important; } +#img2img_sketch, #img2maskimg, #inpaint_sketch { overflow: overlay !important; resize: auto; background: var(--panel-background-fill); z-index: 5; } +*/ + /* color elements */ .gradio-dropdown, .block.gradio-slider, .block.gradio-checkbox, .block.gradio-textbox, .block.gradio-radio, .block.gradio-checkboxgroup, .block.gradio-number, .block.gradio-colorpicker { border-width: 0 !important; box-shadow: none !important;} .gradio-accordion { padding-top: var(--spacing-md) !important; padding-right: 0 !important; padding-bottom: 0 !important; color: var(--body-text-color); } @@ -83,13 +96,12 @@ button.custom-button { border-radius: var(--button-large-radius); padding: var(- .block.token-counter div{ display: inline; } .block.token-counter span{ padding: 0.1em 0.75em; } .performance { font-size: var(--text-xs); color: #444; } -.performance p { display: inline-block; color: var(--body-text-color-subdued) !important } +.performance p { display: inline-block; color: var(--primary-500) !important } .performance .time { margin-right: 0; } .thumbnails { background: var(--body-background-fill); } -.control-image { height: calc(100vw/3) !important; } .prompt textarea { resize: vertical; } +.grid-wrap { overflow-y: auto !important; } #control_results { margin: 0; padding: 0; } -#control_gallery { height: calc(100vw/3 + 60px); } #txt2img_gallery, #img2img_gallery { height: 50vh; } #control-result { background: var(--button-secondary-background-fill); padding: 0.2em; } #control-inputs { margin-top: 1em; } @@ -105,7 +117,6 @@ button.custom-button { border-radius: var(--button-large-radius); padding: var(- #txt2img_prompt, #txt2img_neg_prompt, #img2img_prompt, #img2img_neg_prompt, #control_prompt, #control_neg_prompt { display: contents; } #txt2img_actions_column, #img2img_actions_column, #control_actions { flex-flow: wrap; justify-content: space-between; } - .interrogate-clip { position: absolute; right: 6em; top: 8px; max-width: fit-content; background: none !important; z-index: 50; } .interrogate-blip { position: absolute; right: 4em; top: 8px; max-width: 
fit-content; background: none !important; z-index: 50; } .interrogate-col { min-width: 0 !important; max-width: fit-content; margin-right: var(--spacing-xxl); } @@ -118,11 +129,9 @@ div#extras_scale_to_tab div.form { flex-direction: row; } #img2img_unused_scale_by_slider { visibility: hidden; width: 0.5em; max-width: 0.5em; min-width: 0.5em; } .inactive{ opacity: 0.5; } div#extras_scale_to_tab div.form { flex-direction: row; } -#mode_img2img .gradio-image>div.fixed-height, #mode_img2img .gradio-image>div.fixed-height img{ height: 480px !important; max-height: 480px !important; min-height: 480px !important; } -#img2img_sketch, #img2maskimg, #inpaint_sketch { overflow: overlay !important; resize: auto; background: var(--panel-background-fill); z-index: 5; } .image-buttons button { min-width: auto; } .infotext { overflow-wrap: break-word; line-height: 1.5em; font-size: 0.95em !important; } -.infotext > p { padding-left: 1em; text-indent: -1em; white-space: pre-wrap; color: var(--block-info-text-color) !important; } +.infotext > p { white-space: pre-wrap; color: var(--block-info-text-color) !important; } .tooltip { display: block; position: fixed; top: 1em; right: 1em; padding: 0.5em; background: var(--input-background-fill); color: var(--body-text-color); border: 1pt solid var(--button-primary-border-color); width: 22em; min-height: 1.3em; font-size: var(--text-xs); transition: opacity 0.2s ease-in; pointer-events: none; opacity: 0; z-index: 999; } .tooltip-show { opacity: 0.9; } @@ -140,29 +149,30 @@ div#extras_scale_to_tab div.form { flex-direction: row; } #settings>div.tab-content { flex: 10 0 75%; display: grid; } #settings>div.tab-content>div { border: none; padding: 0; } #settings>div.tab-content>div>div>div>div>div { flex-direction: unset; } -#settings>div.tab-nav { display: grid; grid-template-columns: repeat(auto-fill, .5em minmax(10em, 1fr)); flex: 1 0 auto; width: 12em; align-self: flex-start; gap: var(--spacing-xxl); } +#settings>div.tab-nav { display: grid; grid-template-columns: repeat(auto-fill, .5em minmax(10em, 1fr)); flex: 1 0 auto; width: 12em; align-self: flex-start; gap: 8px; } #settings>div.tab-nav button { display: block; border: none; text-align: left; white-space: initial; padding: 0; } #settings>div.tab-nav>#settings_show_all_pages { padding: var(--size-2) var(--size-4); } #settings .block.gradio-checkbox { margin: 0; width: auto; } #settings .dirtyable { gap: .5em; } #settings .dirtyable.hidden { display: none; } -#settings .modification-indicator { height: 1.2em; border-radius: 1em !important; padding: 0; width: 0; margin-right: 0.5em; } +#settings .modification-indicator { height: 1.2em; border-radius: 1em !important; padding: 0; width: 0; margin-right: 0.5em; border-left: inset; } #settings .modification-indicator:disabled { visibility: hidden; } #settings .modification-indicator.saved { background: var(--color-accent-soft); width: var(--spacing-sm); } #settings .modification-indicator.changed { background: var(--color-accent); width: var(--spacing-sm); } #settings .modification-indicator.changed.unsaved { background-image: linear-gradient(var(--color-accent) 25%, var(--color-accent-soft) 75%); width: var(--spacing-sm); } #settings_result { margin: 0 1.2em; } +#tab_settings .gradio-slider, #tab_settings .gradio-dropdown { width: 300px !important; max-width: 300px; } +#tab_settings textarea { max-width: 500px; } .licenses { display: block !important; } /* live preview */ .progressDiv { position: relative; height: 20px; background: #b4c0cc; margin-bottom: -3px; } 
.dark .progressDiv { background: #424c5b; } .progressDiv .progress { width: 0%; height: 20px; background: #0060df; color: white; font-weight: bold; line-height: 20px; padding: 0 8px 0 0; text-align: right; overflow: visible; white-space: nowrap; padding: 0 0.5em; } -.livePreview { position: absolute; z-index: 50; background-color: transparent; width: -moz-available; width: -webkit-fill-available; } -.livePreview img { position: absolute; object-fit: contain; width: 100%; height: 100%; } -.dark .livePreview { background-color: rgb(17 24 39 / var(--tw-bg-opacity)); } +.livePreview { position: absolute; z-index: 50; width: -moz-available; width: -webkit-fill-available; height: 100%; background-color: var(--background-color); } +.livePreview img { object-fit: contain; width: 100%; justify-self: center; } .popup-metadata { color: white; background: #0000; display: inline-block; white-space: pre-wrap; font-size: var(--text-xxs); } - +.generating { animation: unset !important; border: unset !important; } /* fullpage image viewer */ #lightboxModal { display: none; position: fixed; z-index: 1001; left: 0; top: 0; width: 100%; height: 100%; overflow: hidden; background-color: rgba(20, 20, 20, 0.75); backdrop-filter: blur(6px); user-select: none; -webkit-user-select: none; flex-direction: row; font-family: 'NotoSans';} @@ -207,7 +217,7 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt .extra_networks_root { width: 0; position: absolute; height: auto; right: 0; top: 13em; z-index: 100; } /* default is sidebar view */ .extra-networks { background: var(--background-color); padding: var(--block-label-padding); } .extra-networks > div { margin: 0; border-bottom: none !important; gap: 0.3em 0; } -.extra-networks .second-line { display: flex; width: -moz-available; width: -webkit-fill-available; gap: 0.3em; box-shadow: var(--input-shadow); } +.extra-networks .second-line { display: flex; width: -moz-available; width: -webkit-fill-available; gap: 0.3em; box-shadow: var(--input-shadow); margin-bottom: 2px; } .extra-networks .search { flex: 1; height: 4em; } .extra-networks .description { flex: 3; } .extra-networks .tab-nav>button { margin-right: 0; height: 24px; padding: 2px 4px 2px 4px; } @@ -216,7 +226,7 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt .extra-networks .custom-button { width: 120px; width: 100%; background: none; justify-content: left; text-align: left; padding: 3px 3px 3px 12px; text-indent: -6px; box-shadow: none; line-break: auto; } .extra-networks .custom-button:hover { background: var(--button-primary-background-fill) } .extra-networks-tab { padding: 0 !important; } -.extra-network-subdirs { background: var(--input-background-fill); overflow-x: hidden; overflow-y: auto; min-width: max(15%, 120px); padding-top: 0.5em; margin-top: -4px !important; } +.extra-network-subdirs { background: var(--input-background-fill); overflow-x: hidden; overflow-y: auto; min-width: max(15%, 120px); padding-top: 0.5em; border-radius: 4px; } .extra-networks-page { display: flex } .extra-network-cards { display: flex; flex-wrap: wrap; overflow-y: auto; overflow-x: hidden; align-content: flex-start; width: -moz-available; width: -webkit-fill-available; } .extra-network-cards .card { height: fit-content; margin: 0 0 0.5em 0.5em; position: relative; scroll-snap-align: start; scroll-margin-top: 0; } @@ -380,8 +390,6 @@ div:has(>#tab-gallery-folders) { flex-grow: 0 !important; background-color: var( #img2img_actions_column { display: flex; 
min-width: fit-content !important; flex-direction: row;justify-content: space-evenly; align-items: center;} #txt2img_generate_box, #img2img_generate_box, #txt2img_enqueue_wrapper,#img2img_enqueue_wrapper {display: flex;flex-direction: column;height: 4em !important;align-items: stretch;justify-content: space-evenly;} #img2img_interface, #img2img_results, #img2img_footer p { text-wrap: wrap; min-width: 100% !important; max-width: 100% !important;} /* maintain single column for from image operations on larger mobile devices */ - #img2img_sketch, #img2maskimg, #inpaint_sketch {display: flex; overflow: auto !important; resize: none !important; } /* fix inpaint image display being too large for mobile displays */ - #img2maskimg canvas { width: auto !important; max-height: 100% !important; height: auto !important; } #txt2img_sampler, #txt2img_batch, #txt2img_seed_group, #txt2img_advanced, #txt2img_second_pass, #img2img_sampling_group, #img2img_resize_group, #img2img_batch_group, #img2img_seed_group, #img2img_denoise_group, #img2img_advanced_group { width: 100% !important; } /* fix from text/image UI elements to prevent them from moving around within the UI */ #img2img_resize_group .gradio-radio>div { display: flex; flex-direction: column; width: unset !important; } #inpaint_controls div { display:flex;flex-direction: row;} diff --git a/javascript/ui.js b/javascript/ui.js index 8808f1c8b..3e3f14390 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -28,7 +28,7 @@ function clip_gallery_urls(gallery) { const files = gallery.map((v) => v.data); navigator.clipboard.writeText(JSON.stringify(files)).then( () => log('clipboard:', files), - (err) => console.error('clipboard:', files, err), + (err) => error(`clipboard: ${files} ${err}`), ); } @@ -139,7 +139,7 @@ function switch_to_inpaint(...args) { return Array.from(arguments); } -function switch_to_inpaint_sketch(...args) { +function switch_to_composite(...args) { switchToTab('Image'); switch_to_img2img_tab(3); return Array.from(arguments); @@ -493,9 +493,9 @@ function previewTheme() { el.src = `/file=html/${name}.jpg`; } }) - .catch((e) => console.error('previewTheme:', e)); + .catch((e) => error(`previewTheme: ${e}`)); }) - .catch((e) => console.error('previewTheme:', e)); + .catch((e) => error(`previewTheme: ${e}`)); } async function browseFolder() { diff --git a/launch.py b/launch.py index f944a7e54..e00da58c7 100755 --- a/launch.py +++ b/launch.py @@ -55,9 +55,11 @@ def get_custom_args(): if 'PS1' in env: del env['PS1'] installer.log.trace(f'Environment: {installer.print_dict(env)}') - else: - env = [f'{k}={v}' for k, v in os.environ.items() if k.startswith('SD_')] - installer.log.debug(f'Env flags: {env}') + env = [f'{k}={v}' for k, v in os.environ.items() if k.startswith('SD_')] + installer.log.debug(f'Env flags: {env}') + ldd = os.environ.get('LD_PRELOAD', None) + if ldd is not None: + installer.log.debug(f'Linker flags: "{ldd}"') @lru_cache() diff --git a/models/Reference/Efficient-Large-Model--Sana_1600M_1024px_diffusers.jpg b/models/Reference/Efficient-Large-Model--Sana_1600M_1024px_diffusers.jpg new file mode 100644 index 000000000..654f85403 Binary files /dev/null and b/models/Reference/Efficient-Large-Model--Sana_1600M_1024px_diffusers.jpg differ diff --git a/modules/api/api.py b/modules/api/api.py index f8346995d..b958085ea 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -35,7 +35,8 @@ def __init__(self, app: FastAPI, queue_lock: Lock): # server api self.add_api_route("/sdapi/v1/motd", server.get_motd, methods=["GET"], 
response_model=str) - self.add_api_route("/sdapi/v1/log", server.get_log_buffer, methods=["GET"], response_model=List[str]) + self.add_api_route("/sdapi/v1/log", server.get_log, methods=["GET"], response_model=List[str]) + self.add_api_route("/sdapi/v1/log", server.post_log, methods=["POST"]) self.add_api_route("/sdapi/v1/start", self.get_session_start, methods=["GET"]) self.add_api_route("/sdapi/v1/version", server.get_version, methods=["GET"]) self.add_api_route("/sdapi/v1/status", server.get_status, methods=["GET"], response_model=models.ResStatus) @@ -90,6 +91,11 @@ def __init__(self, app: FastAPI, queue_lock: Lock): self.add_api_route("/sdapi/v1/history", endpoints.get_history, methods=["GET"], response_model=List[str]) self.add_api_route("/sdapi/v1/history", endpoints.post_history, methods=["POST"], response_model=int) + # lora api + if shared.native: + self.add_api_route("/sdapi/v1/loras", endpoints.get_loras, methods=["GET"], response_model=List[dict]) + self.add_api_route("/sdapi/v1/refresh-loras", endpoints.post_refresh_loras, methods=["POST"]) + # gallery api gallery.register_api(app) diff --git a/modules/api/control.py b/modules/api/control.py index 29c5a77f1..411f71ff8 100644 --- a/modules/api/control.py +++ b/modules/api/control.py @@ -159,6 +159,8 @@ def post_control(self, req: ReqControl): output_images = [] output_processed = [] output_info = '' + # TODO control: support scripts via api + # init script args, call scripts.script_control.run, call scripts.script_control.after run.control_set({ 'do_not_save_grid': not req.save_images, 'do_not_save_samples': not req.save_images, **self.prepare_ip_adapter(req) }) run.control_set(getattr(req, "extra", {})) res = run.control_run(**args) diff --git a/modules/api/endpoints.py b/modules/api/endpoints.py index 61993db84..1c56b7171 100644 --- a/modules/api/endpoints.py +++ b/modules/api/endpoints.py @@ -40,6 +40,12 @@ def convert_embeddings(embeddings): return {"loaded": convert_embeddings(db.word_embeddings), "skipped": convert_embeddings(db.skipped_embeddings)} +def get_loras(): + from modules.lora import network, networks + def create_lora_json(obj: network.NetworkOnDisk): + return { "name": obj.name, "alias": obj.alias, "path": obj.filename, "metadata": obj.metadata } + return [create_lora_json(obj) for obj in networks.available_networks.values()] + def get_extra_networks(page: Optional[str] = None, name: Optional[str] = None, filename: Optional[str] = None, title: Optional[str] = None, fullname: Optional[str] = None, hash: Optional[str] = None): # pylint: disable=redefined-builtin res = [] for pg in shared.extra_networks: @@ -126,6 +132,10 @@ def post_refresh_checkpoints(): def post_refresh_vae(): return shared.refresh_vaes() +def post_refresh_loras(): + from modules.lora import networks + return networks.list_available_networks() + def get_extensions_list(): from modules import extensions extensions.list_extensions() diff --git a/modules/api/generate.py b/modules/api/generate.py index b8ee645a4..9b409a14b 100644 --- a/modules/api/generate.py +++ b/modules/api/generate.py @@ -116,6 +116,8 @@ def post_text2img(self, txt2imgreq: models.ReqTxt2Img): processed = scripts.scripts_txt2img.run(p, *script_args) # Need to pass args as list here else: processed = process_images(p) + processed = scripts.scripts_txt2img.after(p, processed, *script_args) + p.close() shared.state.end(api=False) if processed is None or processed.images is None or len(processed.images) == 0: b64images = [] @@ -166,6 +168,8 @@ def post_img2img(self, img2imgreq: 
models.ReqImg2Img): processed = scripts.scripts_img2img.run(p, *script_args) # Need to pass args as list here else: processed = process_images(p) + processed = scripts.scripts_img2img.after(p, processed, *script_args) + p.close() shared.state.end(api=False) if processed is None or processed.images is None or len(processed.images) == 0: b64images = [] diff --git a/modules/api/models.py b/modules/api/models.py index e68ebf081..39bcbe383 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -286,10 +286,16 @@ class ResImageInfo(BaseModel): items: dict = Field(title="Items", description="A dictionary containing all the other fields the image had") parameters: dict = Field(title="Parameters", description="A dictionary with parsed generation info fields") -class ReqLog(BaseModel): +class ReqGetLog(BaseModel): lines: int = Field(default=100, title="Lines", description="How many lines to return") clear: bool = Field(default=False, title="Clear", description="Should the log be cleared after returning the lines?") + +class ReqPostLog(BaseModel): + message: Optional[str] = Field(title="Message", description="The info message to log") + debug: Optional[str] = Field(title="Debug message", description="The debug message to log") + error: Optional[str] = Field(title="Error message", description="The error message to log") + class ReqProgress(BaseModel): skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization") diff --git a/modules/api/server.py b/modules/api/server.py index 939e19c86..dabbe634c 100644 --- a/modules/api/server.py +++ b/modules/api/server.py @@ -37,12 +37,22 @@ def get_platform(): from modules.loader import get_packages as loader_get_packages return { **installer_get_platform(), **loader_get_packages() } -def get_log_buffer(req: models.ReqLog = Depends()): +def get_log(req: models.ReqGetLog = Depends()): lines = shared.log.buffer[:req.lines] if req.lines > 0 else shared.log.buffer.copy() if req.clear: shared.log.buffer.clear() return lines +def post_log(req: models.ReqPostLog): + if req.message is not None: + shared.log.info(f'UI: {req.message}') + if req.debug is not None: + shared.log.debug(f'UI: {req.debug}') + if req.error is not None: + shared.log.error(f'UI: {req.error}') + return {} + + def get_config(): options = {} for k in shared.opts.data.keys(): diff --git a/modules/call_queue.py b/modules/call_queue.py index 4065d13d9..11ba7b56e 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -2,7 +2,7 @@ import threading import time import cProfile -from modules import shared, progress, errors +from modules import shared, progress, errors, timer queue_lock = threading.Lock() @@ -73,15 +73,20 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs): elapsed_m = int(elapsed // 60) elapsed_s = elapsed % 60 elapsed_text = f"{elapsed_m}m {elapsed_s:.2f}s" if elapsed_m > 0 else f"{elapsed_s:.2f}s" - vram_html = '' + summary = timer.process.summary(min_time=0.25, total=False).replace('=', ' ') + gpu = '' + cpu = '' if not shared.mem_mon.disabled: vram = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.read().items()} - if vram.get('active_peak', 0) > 0: - vram_html = " |

" - vram_html += f"GPU active {max(vram['active_peak'], vram['reserved_peak'])} MB reserved {vram['reserved']} | used {vram['used']} MB free {vram['free']} MB total {vram['total']} MB" - vram_html += f" | retries {vram['retries']} oom {vram['oom']}" if vram.get('retries', 0) > 0 or vram.get('oom', 0) > 0 else '' - vram_html += "

" + peak = max(vram['active_peak'], vram['reserved_peak'], vram['used']) + used = round(100.0 * peak / vram['total']) if vram['total'] > 0 else 0 + if used > 0: + gpu += f"| GPU {peak} MB {used}%" + gpu += f" | retries {vram['retries']} oom {vram['oom']}" if vram.get('retries', 0) > 0 or vram.get('oom', 0) > 0 else '' + ram = shared.ram_stats() + if ram['used'] > 0: + cpu += f"| RAM {ram['used']} GB {round(100.0 * ram['used'] / ram['total'])}%" if isinstance(res, list): - res[-1] += f"

Time: {elapsed_text}

{vram_html}
" + res[-1] += f"

Time: {elapsed_text} | {summary} {gpu} {cpu}

" return tuple(res) return f diff --git a/modules/consistory/consistory_unet_sdxl.py b/modules/consistory/consistory_unet_sdxl.py index 4dd9b42d2..940b4ba01 100644 --- a/modules/consistory/consistory_unet_sdxl.py +++ b/modules/consistory/consistory_unet_sdxl.py @@ -916,7 +916,6 @@ def forward( # 1. time timesteps = timestep if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): diff --git a/modules/control/run.py b/modules/control/run.py index 5d6343c98..8cecb93af 100644 --- a/modules/control/run.py +++ b/modules/control/run.py @@ -7,6 +7,7 @@ from modules.control import util # helper functions from modules.control import unit # control units from modules.control import processors # image preprocessors +from modules.control import tile # tiling module from modules.control.units import controlnet # lllyasviel ControlNet from modules.control.units import xs # VisLearn ControlNet-XS from modules.control.units import lite # Kohya ControlLLLite @@ -44,6 +45,167 @@ def terminate(msg): return msg +def set_pipe(p, has_models, unit_type, selected_models, active_model, active_strength, control_conditioning, control_guidance_start, control_guidance_end, inits): + global pipe, instance # pylint: disable=global-statement + pipe = None + if has_models: + p.ops.append('control') + p.extra_generation_params["Control type"] = unit_type # overriden later with pretty-print + p.extra_generation_params["Control model"] = ';'.join([(m.model_id or '') for m in active_model if m.model is not None]) + p.extra_generation_params["Control conditioning"] = control_conditioning if isinstance(control_conditioning, list) else [control_conditioning] + p.extra_generation_params['Control start'] = control_guidance_start if isinstance(control_guidance_start, list) else [control_guidance_start] + p.extra_generation_params['Control end'] = control_guidance_end if isinstance(control_guidance_end, list) else [control_guidance_end] + p.extra_generation_params["Control conditioning"] = ';'.join([str(c) for c in p.extra_generation_params["Control conditioning"]]) + p.extra_generation_params['Control start'] = ';'.join([str(c) for c in p.extra_generation_params['Control start']]) + p.extra_generation_params['Control end'] = ';'.join([str(c) for c in p.extra_generation_params['Control end']]) + if unit_type == 't2i adapter' and has_models: + p.extra_generation_params["Control type"] = 'T2I-Adapter' + p.task_args['adapter_conditioning_scale'] = control_conditioning + instance = t2iadapter.AdapterPipeline(selected_models, shared.sd_model) + pipe = instance.pipeline + if inits is not None: + shared.log.warning('Control: T2I-Adapter does not support separate init image') + elif unit_type == 'controlnet' and has_models: + p.extra_generation_params["Control type"] = 'ControlNet' + p.task_args['controlnet_conditioning_scale'] = control_conditioning + p.task_args['control_guidance_start'] = control_guidance_start + p.task_args['control_guidance_end'] = control_guidance_end + p.task_args['guess_mode'] = p.guess_mode + instance = controlnet.ControlNetPipeline(selected_models, shared.sd_model, p=p) + pipe = instance.pipeline + elif unit_type == 'xs' and has_models: + p.extra_generation_params["Control type"] = 'ControlNet-XS' + p.controlnet_conditioning_scale = control_conditioning + p.control_guidance_start = control_guidance_start 
+ p.control_guidance_end = control_guidance_end + instance = xs.ControlNetXSPipeline(selected_models, shared.sd_model) + pipe = instance.pipeline + if inits is not None: + shared.log.warning('Control: ControlNet-XS does not support separate init image') + elif unit_type == 'lite' and has_models: + p.extra_generation_params["Control type"] = 'ControlLLLite' + p.controlnet_conditioning_scale = control_conditioning + instance = lite.ControlLLitePipeline(shared.sd_model) + pipe = instance.pipeline + if inits is not None: + shared.log.warning('Control: ControlLLLite does not support separate init image') + elif unit_type == 'reference' and has_models: + p.extra_generation_params["Control type"] = 'Reference' + p.extra_generation_params["Control attention"] = p.attention + p.task_args['reference_attn'] = 'Attention' in p.attention + p.task_args['reference_adain'] = 'Adain' in p.attention + p.task_args['attention_auto_machine_weight'] = p.query_weight + p.task_args['gn_auto_machine_weight'] = p.adain_weight + p.task_args['style_fidelity'] = p.fidelity + instance = reference.ReferencePipeline(shared.sd_model) + pipe = instance.pipeline + if inits is not None: + shared.log.warning('Control: ControlNet-XS does not support separate init image') + else: # run in txt2img/img2img mode + if len(active_strength) > 0: + p.strength = active_strength[0] + pipe = shared.sd_model + instance = None + debug(f'Control: run type={unit_type} models={has_models} pipe={pipe.__class__.__name__ if pipe is not None else None}') + return pipe + + +def check_active(p, unit_type, units): + active_process: List[processors.Processor] = [] # all active preprocessors + active_model: List[Union[controlnet.ControlNet, xs.ControlNetXS, t2iadapter.Adapter]] = [] # all active models + active_strength: List[float] = [] # strength factors for all active models + active_start: List[float] = [] # start step for all active models + active_end: List[float] = [] # end step for all active models + num_units = 0 + for u in units: + if u.type != unit_type: + continue + num_units += 1 + debug(f'Control unit: i={num_units} type={u.type} enabled={u.enabled}') + if not u.enabled: + if u.controlnet is not None and u.controlnet.model is not None: + debug(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.cpu}') + sd_models.move_model(u.controlnet.model, devices.cpu) + continue + if u.controlnet is not None and u.controlnet.model is not None: + debug(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.device}') + sd_models.move_model(u.controlnet.model, devices.device) + if unit_type == 't2i adapter' and u.adapter.model is not None: + active_process.append(u.process) + active_model.append(u.adapter) + active_strength.append(float(u.strength)) + p.adapter_conditioning_factor = u.factor + shared.log.debug(f'Control T2I-Adapter unit: i={num_units} process="{u.process.processor_id}" model="{u.adapter.model_id}" strength={u.strength} factor={u.factor}') + elif unit_type == 'controlnet' and u.controlnet.model is not None: + active_process.append(u.process) + active_model.append(u.controlnet) + active_strength.append(float(u.strength)) + active_start.append(float(u.start)) + active_end.append(float(u.end)) + p.guess_mode = u.guess + if isinstance(u.mode, str): + p.control_mode = u.choices.index(u.mode) if u.mode in u.choices else 0 + p.is_tile = p.is_tile or 'tile' in u.mode.lower() + p.control_tile = u.tile + p.extra_generation_params["Control mode"] = u.mode + shared.log.debug(f'Control ControlNet unit: 
i={num_units} process="{u.process.processor_id}" model="{u.controlnet.model_id}" strength={u.strength} guess={u.guess} start={u.start} end={u.end} mode={u.mode}') + elif unit_type == 'xs' and u.controlnet.model is not None: + active_process.append(u.process) + active_model.append(u.controlnet) + active_strength.append(float(u.strength)) + active_start.append(float(u.start)) + active_end.append(float(u.end)) + shared.log.debug(f'Control ControlNet-XS unit: i={num_units} process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}') + elif unit_type == 'lite' and u.controlnet.model is not None: + active_process.append(u.process) + active_model.append(u.controlnet) + active_strength.append(float(u.strength)) + shared.log.debug(f'Control ControlLLite unit: i={num_units} process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}') + elif unit_type == 'reference': + p.override = u.override + p.attention = u.attention + p.query_weight = float(u.query_weight) + p.adain_weight = float(u.adain_weight) + p.fidelity = u.fidelity + shared.log.debug('Control Reference unit') + else: + if u.process.processor_id is not None: + active_process.append(u.process) + shared.log.debug(f'Control process unit: i={num_units} process={u.process.processor_id}') + active_strength.append(float(u.strength)) + debug(f'Control active: process={len(active_process)} model={len(active_model)}') + return active_process, active_model, active_strength, active_start, active_end + + +def check_enabled(p, unit_type, units, active_model, active_strength, active_start, active_end): + has_models = False + selected_models: List[Union[controlnet.ControlNetModel, xs.ControlNetXSModel, t2iadapter.AdapterModel]] = None + control_conditioning = None + control_guidance_start = None + control_guidance_end = None + if unit_type == 't2i adapter' or unit_type == 'controlnet' or unit_type == 'xs' or unit_type == 'lite': + if len(active_model) == 0: + selected_models = None + elif len(active_model) == 1: + selected_models = active_model[0].model if active_model[0].model is not None else None + p.is_tile = p.is_tile or 'tile' in active_model[0].model_id.lower() + has_models = selected_models is not None + control_conditioning = active_strength[0] if len(active_strength) > 0 else 1 # strength or list[strength] + control_guidance_start = active_start[0] if len(active_start) > 0 else 0 + control_guidance_end = active_end[0] if len(active_end) > 0 else 1 + else: + selected_models = [m.model for m in active_model if m.model is not None] + has_models = len(selected_models) > 0 + control_conditioning = active_strength[0] if len(active_strength) == 1 else list(active_strength) # strength or list[strength] + control_guidance_start = active_start[0] if len(active_start) == 1 else list(active_start) + control_guidance_end = active_end[0] if len(active_end) == 1 else list(active_end) + elif unit_type == 'reference': + has_models = any(u.enabled for u in units if u.type == 'reference') + else: + pass + return has_models, selected_models, control_conditioning, control_guidance_start, control_guidance_end + + def control_set(kwargs): if kwargs: global p_extra_args # pylint: disable=global-statement @@ -83,20 +245,15 @@ def control_run(state: str = '', u.adapter.load(u.model_name, force=False) else: u.controlnet.load(u.model_name, force=False) + u.update_choices(u.model_name) if u.process is not None and u.process.override is 
None and u.override is not None: u.process.override = u.override - global instance, pipe, original_pipeline # pylint: disable=global-statement - t_start = time.time() + global pipe, original_pipeline # pylint: disable=global-statement debug(f'Control: type={unit_type} input={inputs} init={inits} type={input_type}') if inputs is None or (type(inputs) is list and len(inputs) == 0): inputs = [None] output_images: List[Image.Image] = [] # output images - active_process: List[processors.Processor] = [] # all active preprocessors - active_model: List[Union[controlnet.ControlNet, xs.ControlNetXS, t2iadapter.Adapter]] = [] # all active models - active_strength: List[float] = [] # strength factors for all active models - active_start: List[float] = [] # start step for all active models - active_end: List[float] = [] # end step for all active models processed_image: Image.Image = None # last processed image if mask is not None and input_type == 0: input_type = 1 # inpaint always requires control_image @@ -150,10 +307,11 @@ def control_run(state: str = '', outpath_grids=shared.opts.outdir_grids or shared.opts.outdir_control_grids, ) p.state = state + p.is_tile = False # processing.process_init(p) resize_mode_before = resize_mode_before if resize_name_before != 'None' and inputs is not None and len(inputs) > 0 else 0 - # TODO monkey-patch for modernui missing tabs.select event + # TODO modernui: monkey-patch for missing tabs.select event if selected_scale_tab_before == 0 and resize_name_before != 'None' and scale_by_before != 1 and inputs is not None and len(inputs) > 0: shared.log.debug('Control: override resize mode=before') selected_scale_tab_before = 1 @@ -224,155 +382,17 @@ def control_run(state: str = '', unit_type = unit_type.strip().lower() if unit_type is not None else '' t0 = time.time() - num_units = 0 - for u in units: - if u.type != unit_type: - continue - num_units += 1 - debug(f'Control unit: i={num_units} type={u.type} enabled={u.enabled}') - if not u.enabled: - if u.controlnet is not None and u.controlnet.model is not None: - debug(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.cpu}') - sd_models.move_model(u.controlnet.model, devices.cpu) - continue - if u.controlnet is not None and u.controlnet.model is not None: - debug(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.device}') - sd_models.move_model(u.controlnet.model, devices.device) - if unit_type == 't2i adapter' and u.adapter.model is not None: - active_process.append(u.process) - active_model.append(u.adapter) - active_strength.append(float(u.strength)) - p.adapter_conditioning_factor = u.factor - shared.log.debug(f'Control T2I-Adapter unit: i={num_units} process={u.process.processor_id} model={u.adapter.model_id} strength={u.strength} factor={u.factor}') - elif unit_type == 'controlnet' and u.controlnet.model is not None: - active_process.append(u.process) - active_model.append(u.controlnet) - active_strength.append(float(u.strength)) - active_start.append(float(u.start)) - active_end.append(float(u.end)) - p.guess_mode = u.guess - p.control_mode = u.mode - shared.log.debug(f'Control ControlNet unit: i={num_units} process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end} mode={u.mode}') - elif unit_type == 'xs' and u.controlnet.model is not None: - active_process.append(u.process) - active_model.append(u.controlnet) - active_strength.append(float(u.strength)) - active_start.append(float(u.start)) - 
active_end.append(float(u.end)) - shared.log.debug(f'Control ControlNet-XS unit: i={num_units} process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}') - elif unit_type == 'lite' and u.controlnet.model is not None: - active_process.append(u.process) - active_model.append(u.controlnet) - active_strength.append(float(u.strength)) - shared.log.debug(f'Control ControlLLite unit: i={num_units} process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}') - elif unit_type == 'reference': - p.override = u.override - p.attention = u.attention - p.query_weight = float(u.query_weight) - p.adain_weight = float(u.adain_weight) - p.fidelity = u.fidelity - shared.log.debug('Control Reference unit') - else: - if u.process.processor_id is not None: - active_process.append(u.process) - shared.log.debug(f'Control process unit: i={num_units} process={u.process.processor_id}') - active_strength.append(float(u.strength)) - debug(f'Control active: process={len(active_process)} model={len(active_model)}') + + active_process, active_model, active_strength, active_start, active_end = check_active(p, unit_type, units) + has_models, selected_models, control_conditioning, control_guidance_start, control_guidance_end = check_enabled(p, unit_type, units, active_model, active_strength, active_start, active_end) processed: processing.Processed = None image_txt = '' info_txt = [] - has_models = False - selected_models: List[Union[controlnet.ControlNetModel, xs.ControlNetXSModel, t2iadapter.AdapterModel]] = None - control_conditioning = None - control_guidance_start = None - control_guidance_end = None - if unit_type == 't2i adapter' or unit_type == 'controlnet' or unit_type == 'xs' or unit_type == 'lite': - if len(active_model) == 0: - selected_models = None - elif len(active_model) == 1: - selected_models = active_model[0].model if active_model[0].model is not None else None - has_models = selected_models is not None - control_conditioning = active_strength[0] if len(active_strength) > 0 else 1 # strength or list[strength] - control_guidance_start = active_start[0] if len(active_start) > 0 else 0 - control_guidance_end = active_end[0] if len(active_end) > 0 else 1 - else: - selected_models = [m.model for m in active_model if m.model is not None] - has_models = len(selected_models) > 0 - control_conditioning = active_strength[0] if len(active_strength) == 1 else list(active_strength) # strength or list[strength] - control_guidance_start = active_start[0] if len(active_start) == 1 else list(active_start) - control_guidance_end = active_end[0] if len(active_end) == 1 else list(active_end) - elif unit_type == 'reference': - has_models = any(u.enabled for u in units if u.type == 'reference') - else: - pass - def set_pipe(): - global pipe, instance # pylint: disable=global-statement - pipe = None - if has_models: - p.ops.append('control') - p.extra_generation_params["Control mode"] = unit_type # overriden later with pretty-print - p.extra_generation_params["Control conditioning"] = control_conditioning if isinstance(control_conditioning, list) else [control_conditioning] - p.extra_generation_params['Control start'] = control_guidance_start if isinstance(control_guidance_start, list) else [control_guidance_start] - p.extra_generation_params['Control end'] = control_guidance_end if isinstance(control_guidance_end, list) else [control_guidance_end] - p.extra_generation_params["Control 
model"] = ';'.join([(m.model_id or '') for m in active_model if m.model is not None]) - p.extra_generation_params["Control conditioning"] = ';'.join([str(c) for c in p.extra_generation_params["Control conditioning"]]) - p.extra_generation_params['Control start'] = ';'.join([str(c) for c in p.extra_generation_params['Control start']]) - p.extra_generation_params['Control end'] = ';'.join([str(c) for c in p.extra_generation_params['Control end']]) - if unit_type == 't2i adapter' and has_models: - p.extra_generation_params["Control mode"] = 'T2I-Adapter' - p.task_args['adapter_conditioning_scale'] = control_conditioning - instance = t2iadapter.AdapterPipeline(selected_models, shared.sd_model) - pipe = instance.pipeline - if inits is not None: - shared.log.warning('Control: T2I-Adapter does not support separate init image') - elif unit_type == 'controlnet' and has_models: - p.extra_generation_params["Control mode"] = 'ControlNet' - p.task_args['controlnet_conditioning_scale'] = control_conditioning - p.task_args['control_guidance_start'] = control_guidance_start - p.task_args['control_guidance_end'] = control_guidance_end - p.task_args['guess_mode'] = p.guess_mode - instance = controlnet.ControlNetPipeline(selected_models, shared.sd_model) - pipe = instance.pipeline - elif unit_type == 'xs' and has_models: - p.extra_generation_params["Control mode"] = 'ControlNet-XS' - p.controlnet_conditioning_scale = control_conditioning - p.control_guidance_start = control_guidance_start - p.control_guidance_end = control_guidance_end - instance = xs.ControlNetXSPipeline(selected_models, shared.sd_model) - pipe = instance.pipeline - if inits is not None: - shared.log.warning('Control: ControlNet-XS does not support separate init image') - elif unit_type == 'lite' and has_models: - p.extra_generation_params["Control mode"] = 'ControlLLLite' - p.controlnet_conditioning_scale = control_conditioning - instance = lite.ControlLLitePipeline(shared.sd_model) - pipe = instance.pipeline - if inits is not None: - shared.log.warning('Control: ControlLLLite does not support separate init image') - elif unit_type == 'reference' and has_models: - p.extra_generation_params["Control mode"] = 'Reference' - p.extra_generation_params["Control attention"] = p.attention - p.task_args['reference_attn'] = 'Attention' in p.attention - p.task_args['reference_adain'] = 'Adain' in p.attention - p.task_args['attention_auto_machine_weight'] = p.query_weight - p.task_args['gn_auto_machine_weight'] = p.adain_weight - p.task_args['style_fidelity'] = p.fidelity - instance = reference.ReferencePipeline(shared.sd_model) - pipe = instance.pipeline - if inits is not None: - shared.log.warning('Control: ControlNet-XS does not support separate init image') - else: # run in txt2img/img2img mode - if len(active_strength) > 0: - p.strength = active_strength[0] - pipe = shared.sd_model - instance = None - debug(f'Control: run type={unit_type} models={has_models} pipe={pipe.__class__.__name__ if pipe is not None else None}') - return pipe - - - pipe = set_pipe() + p.is_tile = p.is_tile and has_models + + pipe = set_pipe(p, has_models, unit_type, selected_models, active_model, active_strength, control_conditioning, control_guidance_start, control_guidance_end, inits) debug(f'Control pipeline: class={pipe.__class__.__name__} args={vars(p)}') t1, t2, t3 = time.time(), 0, 0 status = True @@ -395,6 +415,8 @@ def set_pipe(): else: original_pipeline = None + possible = sd_models.get_call(pipe).keys() + try: with devices.inference_context(): if 
isinstance(inputs, str): # only video, the rest is a list @@ -424,7 +446,7 @@ def set_pipe(): while status: if pipe is None: # pipe may have been reset externally - pipe = set_pipe() + pipe = set_pipe(p, has_models, unit_type, selected_models, active_model, active_strength, control_conditioning, control_guidance_start, control_guidance_end, inits) debug(f'Control pipeline reinit: class={pipe.__class__.__name__}') processed_image = None if frame is not None: @@ -564,19 +586,29 @@ def set_pipe(): return [], '', '', 'Reference mode without image' elif unit_type == 'controlnet' and has_models: if input_type == 0: # Control only - if shared.sd_model_type in ['f1', 'sd3'] and 'control_image' not in p.task_args: - p.task_args['control_image'] = p.init_images # some controlnets mandate this + if 'control_image' in possible: + p.task_args['control_image'] = [p.init_images] if isinstance(p.init_images, Image.Image) else p.init_images + elif 'image' in possible: + p.task_args['image'] = [p.init_images] if isinstance(p.init_images, Image.Image) else p.init_images + if 'control_mode' in possible: + p.task_args['control_mode'] = getattr(p, 'control_mode', None) + if 'strength' in possible: p.task_args['strength'] = p.denoising_strength + p.init_images = None elif input_type == 1: # Init image same as control - p.task_args['control_image'] = p.init_images # switch image and control_image - p.task_args['strength'] = p.denoising_strength + if 'control_image' in possible: + p.task_args['control_image'] = p.init_images # switch image and control_image + if 'strength' in possible: + p.task_args['strength'] = p.denoising_strength p.init_images = [p.override or input_image] * len(active_model) elif input_type == 2: # Separate init image if init_image is None: shared.log.warning('Control: separate init image not provided') init_image = input_image - p.task_args['control_image'] = p.init_images # switch image and control_image - p.task_args['strength'] = p.denoising_strength + if 'control_image' in possible: + p.task_args['control_image'] = p.init_images # switch image and control_image + if 'strength' in possible: + p.task_args['strength'] = p.denoising_strength p.init_images = [init_image] * len(active_model) if is_generator: @@ -609,26 +641,31 @@ def set_pipe(): p.task_args['strength'] = denoising_strength p.image_mask = mask shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.INPAINTING) # only controlnet supports inpaint - elif 'control_image' in p.task_args: + if hasattr(p, 'init_images') and p.init_images is not None: shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE) # only controlnet supports img2img else: shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.TEXT_2_IMAGE) - if hasattr(p, 'init_images') and p.init_images is not None: + if hasattr(p, 'init_images') and p.init_images is not None and 'image' in possible: p.task_args['image'] = p.init_images # need to set explicitly for txt2img del p.init_images if unit_type == 'lite': p.init_image = [input_image] instance.apply(selected_models, processed_image, control_conditioning) - if p.control_mode is not None: - p.task_args['control_mode'] = p.control_mode + if getattr(p, 'control_mode', None) is not None: + p.task_args['control_mode'] = getattr(p, 'control_mode', None) if hasattr(p, 'init_images') and p.init_images is None: # delete empty del p.init_images # final check if has_models: - if unit_type in ['controlnet', 't2i 
adapter', 'lite', 'xs'] and p.task_args.get('image', None) is None and getattr(p, 'init_images', None) is None: + if unit_type in ['controlnet', 't2i adapter', 'lite', 'xs'] \ + and p.task_args.get('image', None) is None \ + and p.task_args.get('control_image', None) is None \ + and getattr(p, 'init_images', None) is None \ + and getattr(p, 'image', None) is None: if is_generator: - yield terminate(f'Mode={p.extra_generation_params.get("Control mode", None)} input image is none') + shared.log.debug(f'Control args: {p.task_args}') + yield terminate(f'Mode={p.extra_generation_params.get("Control type", None)} input image is none') return [], '', '', 'Error: Input image is none' # resize mask @@ -658,12 +695,19 @@ def set_pipe(): script_runner.initialize_scripts(False) p.script_args = script.init_default_script_args(script_runner) - processed = p.scripts.run(p, *p.script_args) + # actual processing + if p.is_tile: + processed: processing.Processed = tile.run_tiling(p, input_image) + if processed is None and p.scripts is not None: + processed = p.scripts.run(p, *p.script_args) if processed is None: processed: processing.Processed = processing.process_images(p) # run actual pipeline else: script_run = True - processed = p.scripts.after(p, processed, *p.script_args) + + # postprocessing + if p.scripts is not None: + processed = p.scripts.after(p, processed, *p.script_args) output = None if processed is not None: output = processed.images @@ -717,14 +761,11 @@ def set_pipe(): shared.log.error(f'Control pipeline failed: type={unit_type} units={len(active_model)} error={e}') errors.display(e, 'Control') - t_end = time.time() - if len(output_images) == 0: output_images = None image_txt = '| Images None' else: - image_str = [f'{image.width}x{image.height}' for image in output_images] - image_txt = f'| Time {t_end-t_start:.2f}s | Images {len(output_images)} | Size {" ".join(image_str)}' + image_txt = '' p.init_images = output_images # may be used for hires if video_type != 'None' and isinstance(output_images, list): @@ -738,10 +779,9 @@ def set_pipe(): restore_pipeline() debug(f'Ready: {image_txt}') - html_txt = f'

Ready {image_txt}' + html_txt = f'Ready {image_txt}
' if image_txt != '' else '' if len(info_txt) > 0: html_txt = html_txt + infotext_to_html(info_txt[0]) if is_generator: yield (output_images, blended_image, html_txt, output_filename) - else: - return (output_images, blended_image, html_txt, output_filename) + return (output_images, blended_image, html_txt, output_filename) diff --git a/modules/control/tile.py b/modules/control/tile.py new file mode 100644 index 000000000..de9df1131 --- /dev/null +++ b/modules/control/tile.py @@ -0,0 +1,73 @@ +import time +from PIL import Image +from modules import shared, processing, images, sd_models + + +def get_tile(image: Image.Image, x: int, y: int, sx: int, sy: int) -> Image.Image: + return image.crop(( + (x + 0) * image.width // sx, + (y + 0) * image.height // sy, + (x + 1) * image.width // sx, + (y + 1) * image.height // sy + )) + + +def set_tile(image: Image.Image, x: int, y: int, tiled: Image.Image): + image.paste(tiled, (x * tiled.width, y * tiled.height)) + return image + + +def run_tiling(p: processing.StableDiffusionProcessing, input_image: Image.Image) -> processing.Processed: + t0 = time.time() + # prepare images + sx, sy = p.control_tile.split('x') + sx = int(sx) + sy = int(sy) + if sx <= 0 or sy <= 0: + raise ValueError('Control Tile: invalid tile size') + control_image = p.task_args.get('control_image', None) or p.task_args.get('image', None) + control_upscaled = None + if isinstance(control_image, list) and len(control_image) > 0: + w, h = 8 * int(sx * control_image[0].width) // 8, 8 * int(sy * control_image[0].height) // 8 + control_upscaled = images.resize_image(resize_mode=1 if sx==sy else 5, im=control_image[0], width=w, height=h, context='add with forward') + init_image = p.override or input_image + init_upscaled = None + if init_image is not None: + w, h = 8 * int(sx * init_image.width) // 8, 8 * int(sy * init_image.height) // 8 + init_upscaled = images.resize_image(resize_mode=1 if sx==sy else 5, im=init_image, width=w, height=h, context='add with forward') + t1 = time.time() + shared.log.debug(f'Control Tile: scale={sx}x{sy} resize={"fixed" if sx==sy else "context"} control={control_upscaled} init={init_upscaled} time={t1-t0:.3f}') + + # stop processing from restoring pipeline on each iteration + orig_restore_pipeline = getattr(shared.sd_model, 'restore_pipeline', None) + shared.sd_model.restore_pipeline = None + + # run tiling + for x in range(sx): + for y in range(sy): + shared.log.info(f'Control Tile: tile={x+1}-{sx}/{y+1}-{sy} target={control_upscaled}') + shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE) + p.init_images = None + p.task_args['control_mode'] = p.control_mode + p.task_args['strength'] = p.denoising_strength + if init_upscaled is not None: + p.task_args['image'] = [get_tile(init_upscaled, x, y, sx, sy)] + if control_upscaled is not None: + p.task_args['control_image'] = [get_tile(control_upscaled, x, y, sx, sy)] + processed: processing.Processed = processing.process_images(p) # run actual pipeline + if processed is None or len(processed.images) == 0: + continue + control_upscaled = set_tile(control_upscaled, x, y, processed.images[0]) + + # post-process + p.width = control_upscaled.width + p.height = control_upscaled.height + processed.images = [control_upscaled] + processed.info = processed.infotext(p, 0) + processed.infotexts = [processed.info] + shared.sd_model.restore_pipeline = orig_restore_pipeline + if hasattr(shared.sd_model, 'restore_pipeline') and shared.sd_model.restore_pipeline is not None: + 
shared.sd_model.restore_pipeline() + t2 = time.time() + shared.log.debug(f'Control Tile: image={control_upscaled} time={t2-t0:.3f}') + return processed diff --git a/modules/control/unit.py b/modules/control/unit.py index 7dc5528a6..eeb729740 100644 --- a/modules/control/unit.py +++ b/modules/control/unit.py @@ -16,6 +16,22 @@ class Unit(): # mashup of gradio controls and mapping to actual implementation classes + def update_choices(self, model_id=None): + name = model_id or self.model_name + if name == 'InstantX Union': + self.choices = ['canny', 'tile', 'depth', 'blur', 'pose', 'gray', 'lq'] + elif name == 'Shakker-Labs Union': + self.choices = ['canny', 'tile', 'depth', 'blur', 'pose', 'gray', 'lq'] + elif name == 'Xinsir Union XL': + self.choices = ['openpose', 'depth', 'scribble', 'canny', 'normal'] + elif name == 'Xinsir ProMax XL': + self.choices = ['openpose', 'depth', 'scribble', 'canny', 'normal', 'segment', 'tile', 'repaint'] + else: + self.choices = ['default'] + + def __str__(self): + return f'Unit: type={self.type} enabled={self.enabled} strength={self.strength} start={self.start} end={self.end} mode={self.mode} tile={self.tile}' + def __init__(self, # values index: int = None, @@ -38,6 +54,7 @@ def __init__(self, control_start = None, control_end = None, control_mode = None, + control_tile = None, result_txt = None, extra_controls: list = [], ): @@ -70,6 +87,10 @@ def __init__(self, self.fidelity = 0.5 self.query_weight = 1.0 self.adain_weight = 1.0 + # control mode + self.choices = ['default'] + # control tile + self.tile = '1x1' def reset(): if self.process is not None: @@ -92,10 +113,16 @@ def control_change(start, end): self.end = max(start, end) def control_mode_change(mode): - self.mode = mode - 1 if mode > 0 else None + self.mode = self.choices.index(mode) if mode is not None and mode in self.choices else 0 + + def control_tile_change(tile): + self.tile = tile - def control_mode_show(model_id): - return gr.update(visible='union' in model_id.lower()) + def control_choices(model_id): + self.update_choices(model_id) + mode_visible = 'union' in model_id.lower() or 'promax' in model_id.lower() + tile_visible = 'union' in model_id.lower() or 'promax' in model_id.lower() or 'tile' in model_id.lower() + return [gr.update(visible=mode_visible, choices=self.choices), gr.update(visible=tile_visible)] def adapter_extra(c1): self.factor = c1 @@ -172,7 +199,7 @@ def set_image(image): else: self.controls.append(model_id) model_id.change(fn=self.controlnet.load, inputs=[model_id], outputs=[result_txt], show_progress=True) - model_id.change(fn=control_mode_show, inputs=[model_id], outputs=[control_mode], show_progress=False) + model_id.change(fn=control_choices, inputs=[model_id], outputs=[control_mode, control_tile], show_progress=False) if extra_controls is not None and len(extra_controls) > 0: extra_controls[0].change(fn=controlnet_extra, inputs=extra_controls) elif self.type == 'xs': @@ -231,3 +258,6 @@ def set_image(image): if control_mode is not None: self.controls.append(control_mode) control_mode.change(fn=control_mode_change, inputs=[control_mode]) + if control_tile is not None: + self.controls.append(control_tile) + control_tile.change(fn=control_tile_change, inputs=[control_tile]) diff --git a/modules/control/units/controlnet.py b/modules/control/units/controlnet.py index 20b99412a..c887aca8f 100644 --- a/modules/control/units/controlnet.py +++ b/modules/control/units/controlnet.py @@ -5,6 +5,7 @@ from modules.control.units import detect from modules.shared import log, opts, 
listdir from modules import errors, sd_models, devices, model_quant +from modules.processing import StableDiffusionProcessingControl what = 'ControlNet' @@ -51,17 +52,20 @@ 'Depth Mid XL': 'diffusers/controlnet-depth-sdxl-1.0-mid', 'OpenPose XL': 'thibaud/controlnet-openpose-sdxl-1.0/bin', 'Xinsir Union XL': 'xinsir/controlnet-union-sdxl-1.0', + 'Xinsir ProMax XL': 'brad-twinkl/controlnet-union-sdxl-1.0-promax', 'Xinsir OpenPose XL': 'xinsir/controlnet-openpose-sdxl-1.0', 'Xinsir Canny XL': 'xinsir/controlnet-canny-sdxl-1.0', 'Xinsir Depth XL': 'xinsir/controlnet-depth-sdxl-1.0', 'Xinsir Scribble XL': 'xinsir/controlnet-scribble-sdxl-1.0', 'Xinsir Anime Painter XL': 'xinsir/anime-painter', + 'Xinsir Tile XL': 'xinsir/controlnet-tile-sdxl-1.0', 'NoobAI Canny XL': 'Eugeoter/noob-sdxl-controlnet-canny', 'NoobAI Lineart Anime XL': 'Eugeoter/noob-sdxl-controlnet-lineart_anime', 'NoobAI Depth XL': 'Eugeoter/noob-sdxl-controlnet-depth', 'NoobAI Normal XL': 'Eugeoter/noob-sdxl-controlnet-normal', 'NoobAI SoftEdge XL': 'Eugeoter/noob-sdxl-controlnet-softedge_hed', 'NoobAI OpenPose XL': 'einar77/noob-openpose', + 'TTPlanet Tile Realistic XL': 'Yakonrus/SDXL_Controlnet_Tile_Realistic_v2', # 'StabilityAI Canny R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-canny-rank128.safetensors', # 'StabilityAI Depth R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-depth-rank128.safetensors', # 'StabilityAI Recolor R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-recolor-rank128.safetensors', @@ -75,6 +79,8 @@ "InstantX Union": 'InstantX/FLUX.1-dev-Controlnet-Union', "InstantX Canny": 'InstantX/FLUX.1-dev-Controlnet-Canny', "JasperAI Depth": 'jasperai/Flux.1-dev-Controlnet-Depth', + "BlackForrestLabs Canny LoRA": '/huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora/flux1-canny-dev-lora.safetensors', + "BlackForrestLabs Depth LoRA": '/huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora/flux1-depth-dev-lora.safetensors', "JasperAI Surface Normals": 'jasperai/Flux.1-dev-Controlnet-Surface-Normals', "JasperAI Upscaler": 'jasperai/Flux.1-dev-Controlnet-Upscaler', "Shakker-Labs Union": 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro', @@ -85,6 +91,9 @@ "XLabs-AI HED": 'XLabs-AI/flux-controlnet-hed-diffusers' } predefined_sd3 = { + "StabilityAI Canny": 'diffusers-internal-dev/sd35-controlnet-canny-8b', + "StabilityAI Depth": 'diffusers-internal-dev/sd35-controlnet-depth-8b', + "StabilityAI Blur": 'diffusers-internal-dev/sd35-controlnet-blur-8b', "InstantX Canny": 'InstantX/SD3-Controlnet-Canny', "InstantX Pose": 'InstantX/SD3-Controlnet-Pose', "InstantX Depth": 'InstantX/SD3-Controlnet-Depth', @@ -92,6 +101,14 @@ "Alimama Inpainting": 'alimama-creative/SD3-Controlnet-Inpainting', "Alimama SoftEdge": 'alimama-creative/SD3-Controlnet-Softedge', } +variants = { + 'NoobAI Canny XL': 'fp16', + 'NoobAI Lineart Anime XL': 'fp16', + 'NoobAI Depth XL': 'fp16', + 'NoobAI Normal XL': 'fp16', + 'NoobAI SoftEdge XL': 'fp16', + 'TTPlanet Tile Realistic XL': 'fp16', +} models = {} all_models = {} all_models.update(predefined_sd15) @@ -159,26 +176,35 @@ def reset(self): self.model = None self.model_id = None - def get_class(self): - import modules.shared - if modules.shared.sd_model_type == 'sd': + def get_class(self, model_id:str=''): + from modules import shared + if shared.sd_model_type == 'none': + _load = shared.sd_model # trigger a load + if shared.sd_model_type == 'sd': from diffusers import ControlNetModel as cls # pylint: disable=reimported config = 
'lllyasviel/control_v11p_sd15_canny' - elif modules.shared.sd_model_type == 'sdxl': - from diffusers import ControlNetModel as cls # pylint: disable=reimported # sdxl shares same model class - config = 'Eugeoter/noob-sdxl-controlnet-canny' - elif modules.shared.sd_model_type == 'f1': + elif shared.sd_model_type == 'sdxl': + if 'union' in model_id.lower(): + from diffusers import ControlNetUnionModel as cls + config = 'xinsir/controlnet-union-sdxl-1.0' + elif 'promax' in model_id.lower(): + from diffusers import ControlNetUnionModel as cls + config = 'brad-twinkl/controlnet-union-sdxl-1.0-promax' + else: + from diffusers import ControlNetModel as cls # pylint: disable=reimported # sdxl shares same model class + config = 'Eugeoter/noob-sdxl-controlnet-canny' + elif shared.sd_model_type == 'f1': from diffusers import FluxControlNetModel as cls config = 'InstantX/FLUX.1-dev-Controlnet-Union' - elif modules.shared.sd_model_type == 'sd3': + elif shared.sd_model_type == 'sd3': from diffusers import SD3ControlNetModel as cls config = 'InstantX/SD3-Controlnet-Canny' else: - log.error(f'Control {what}: type={modules.shared.sd_model_type} unsupported model') + log.error(f'Control {what}: type={shared.sd_model_type} unsupported model') return None, None return cls, config - def load_safetensors(self, model_path): + def load_safetensors(self, model_id, model_path): name = os.path.splitext(model_path)[0] config_path = None if not os.path.exists(model_path): @@ -203,7 +229,7 @@ def load_safetensors(self, model_path): config_path = f'{name}.json' if config_path is not None: self.load_config['original_config_file '] = config_path - cls, config = self.get_class() + cls, config = self.get_class(model_id) if cls is None: log.error(f'Control {what} model load failed: unknown base model') else: @@ -225,23 +251,26 @@ def load(self, model_id: str = None, force: bool = True) -> str: if model_path is None: log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id') return + if 'lora' in model_id.lower(): + self.model = model_path + return if model_id == self.model_id and not force: log.debug(f'Control {what} model: id="{model_id}" path="{model_path}" already loaded') return log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"') + cls, _config = self.get_class(model_id) if model_path.endswith('.safetensors'): - self.load_safetensors(model_path) + self.load_safetensors(model_id, model_path) else: kwargs = {} if '/bin' in model_path: model_path = model_path.replace('/bin', '') self.load_config['use_safetensors'] = False - cls, _config = self.get_class() if cls is None: log.error(f'Control {what} model load failed: id="{model_id}" unknown base model') return - if 'Eugeoter' in model_path: - kwargs['variant'] = 'fp16' + if variants.get(model_id, None) is not None: + kwargs['variant'] = variants[model_id] self.model = cls.from_pretrained(model_path, **self.load_config, **kwargs) if self.model is None: return @@ -268,7 +297,7 @@ def load(self, model_id: str = None, force: bool = True) -> str: self.model.to(self.device) t1 = time.time() self.model_id = model_id - log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}') + log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" cls={cls.__name__} time={t1-t0:.2f}') return f'{what} loaded model: {model_id}' except Exception as e: log.error(f'Control {what} model load failed: id="{model_id}" error={e}') @@ -281,16 +310,27 @@ def __init__(self, controlnet: 
Union[ControlNetModel, list[ControlNetModel]], pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline, FluxPipeline, StableDiffusion3Pipeline], dtype = None, + p: StableDiffusionProcessingControl = None, # pylint: disable=unused-argument ): t0 = time.time() self.orig_pipeline = pipeline self.pipeline = None + + controlnets = controlnet if isinstance(controlnet, list) else [controlnet] + loras = [cn for cn in controlnets if isinstance(cn, str)] + controlnets = [cn for cn in controlnets if not isinstance(cn, str)] + if pipeline is None: log.error('Control model pipeline: model not loaded') return - elif detect.is_sdxl(pipeline): - from diffusers import StableDiffusionXLControlNetPipeline - self.pipeline = StableDiffusionXLControlNetPipeline( + elif detect.is_sdxl(pipeline) and len(controlnets) > 0: + from diffusers import StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetUnionPipeline + if controlnet.__class__.__name__ == 'ControlNetUnionModel': + cls = StableDiffusionXLControlNetUnionPipeline + controlnets = controlnets[0] # using only first one + else: + cls = StableDiffusionXLControlNetPipeline + self.pipeline = cls( vae=pipeline.vae, text_encoder=pipeline.text_encoder, text_encoder_2=pipeline.text_encoder_2, @@ -299,9 +339,9 @@ def __init__(self, unet=pipeline.unet, scheduler=pipeline.scheduler, feature_extractor=getattr(pipeline, 'feature_extractor', None), - controlnet=controlnet, # can be a list + controlnet=controlnets, # can be a list ) - elif detect.is_sd15(pipeline): + elif detect.is_sd15(pipeline) and len(controlnets) > 0: from diffusers import StableDiffusionControlNetPipeline self.pipeline = StableDiffusionControlNetPipeline( vae=pipeline.vae, @@ -312,10 +352,10 @@ def __init__(self, feature_extractor=getattr(pipeline, 'feature_extractor', None), requires_safety_checker=False, safety_checker=None, - controlnet=controlnet, # can be a list + controlnet=controlnets, # can be a list ) sd_models.move_model(self.pipeline, pipeline.device) - elif detect.is_f1(pipeline): + elif detect.is_f1(pipeline) and len(controlnets) > 0: from diffusers import FluxControlNetPipeline self.pipeline = FluxControlNetPipeline( vae=pipeline.vae.to(devices.device), @@ -325,9 +365,9 @@ def __init__(self, tokenizer_2=pipeline.tokenizer_2, transformer=pipeline.transformer, scheduler=pipeline.scheduler, - controlnet=controlnet, # can be a list + controlnet=controlnets, # can be a list ) - elif detect.is_sd3(pipeline): + elif detect.is_sd3(pipeline) and len(controlnets) > 0: from diffusers import StableDiffusion3ControlNetPipeline self.pipeline = StableDiffusion3ControlNetPipeline( vae=pipeline.vae, @@ -339,8 +379,18 @@ def __init__(self, tokenizer_3=pipeline.tokenizer_3, transformer=pipeline.transformer, scheduler=pipeline.scheduler, - controlnet=controlnet, # can be a list + controlnet=controlnets, # can be a list ) + elif len(loras) > 0: + self.pipeline = pipeline + for lora in loras: + log.debug(f'Control {what} pipeline: lora="{lora}"') + lora = lora.replace('/huggingface.co/', '') + self.pipeline.load_lora_weights(lora) + """ + if p is not None: + p.prompt += f'' + """ else: log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type') return @@ -350,6 +400,7 @@ def __init__(self, return if dtype is not None: self.pipeline = self.pipeline.to(dtype) + if opts.diffusers_offload_mode == 'none': sd_models.move_model(self.pipeline, devices.device) from modules.sd_models import set_diffuser_offload @@ -359,5 +410,6 @@ def __init__(self, 
log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}') def restore(self): + self.pipeline.unload_lora_weights() self.pipeline = None return self.orig_pipeline diff --git a/modules/ctrlx/__init__.py b/modules/ctrlx/__init__.py index 87ff52f6c..07d06aeac 100644 --- a/modules/ctrlx/__init__.py +++ b/modules/ctrlx/__init__.py @@ -136,7 +136,7 @@ def appearance_guidance_scale(self): @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]] = None, # TODO: Support prompt_2 and negative_prompt_2 + prompt: Union[str, List[str]] = None, structure_prompt: Optional[Union[str, List[str]]] = None, appearance_prompt: Optional[Union[str, List[str]]] = None, structure_image: Optional[PipelineImageInput] = None, @@ -180,7 +180,6 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): - # TODO: Add function argument documentation callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) @@ -205,7 +204,7 @@ def __call__( target_size = target_size or (height, width) # 1. Check inputs. Raise error if not correct - self.check_inputs( # TODO: Custom check_inputs for our method + self.check_inputs( prompt, None, # prompt_2 height, @@ -425,7 +424,7 @@ def denoising_value_valid(dnv): # 7.2 Optionally get guidance scale embedding timestep_cond = None - if self.unet.config.time_cond_proj_dim is not None: # TODO: Make guidance scale embedding work with batch_order + if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim @@ -457,7 +456,6 @@ def denoising_value_valid(dnv): register_attr(self, t=t.item(), do_control=True, batch_order=batch_order) - # TODO: For now, assume we are doing classifier-free guidance, support no CF-guidance later latent_model_input = self.scheduler.scale_model_input(latents, t) structure_latent_model_input = self.scheduler.scale_model_input(structure_latents, t) appearance_latent_model_input = self.scheduler.scale_model_input(appearance_latents, t) @@ -563,7 +561,7 @@ def denoising_value_valid(dnv): # Self-recurrence for _ in range(self_recurrence_schedule[i]): if hasattr(self.scheduler, "_step_index"): # For fancier schedulers - self.scheduler._step_index -= 1 # TODO: Does this actually work? 
+ self.scheduler._step_index -= 1 t_prev = 0 if i + 1 >= num_inference_steps else timesteps[i + 1] latents = noise_t2t(self.scheduler, t_prev, t, latents) diff --git a/modules/devices.py b/modules/devices.py index 56ac50091..949fab4aa 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -186,53 +186,68 @@ def get_device_for(task): # pylint: disable=unused-argument return get_optimal_device() -def torch_gc(force=False, fast=False): +def torch_gc(force:bool=False, fast:bool=False, reason:str=None): + def get_stats(): + mem_dict = memstats.memory_stats() + gpu_dict = mem_dict.get('gpu', {}) + ram_dict = mem_dict.get('ram', {}) + oom = gpu_dict.get('oom', 0) + ram = ram_dict.get('used', 0) + if backend == "directml": + gpu = torch.cuda.memory_allocated() / (1 << 30) + else: + gpu = gpu_dict.get('used', 0) + used_gpu = round(100 * gpu / gpu_dict.get('total', 1)) if gpu_dict.get('total', 1) > 1 else 0 + used_ram = round(100 * ram / ram_dict.get('total', 1)) if ram_dict.get('total', 1) > 1 else 0 + return gpu, used_gpu, ram, used_ram, oom + + global previous_oom # pylint: disable=global-statement import gc from modules import timer, memstats from modules.shared import cmd_opts + t0 = time.time() - mem = memstats.memory_stats() - gpu = mem.get('gpu', {}) - ram = mem.get('ram', {}) - oom = gpu.get('oom', 0) - if backend == "directml": - used_gpu = round(100 * torch.cuda.memory_allocated() / (1 << 30) / gpu.get('total', 1)) if gpu.get('total', 1) > 1 else 0 - else: - used_gpu = round(100 * gpu.get('used', 0) / gpu.get('total', 1)) if gpu.get('total', 1) > 1 else 0 - used_ram = round(100 * ram.get('used', 0) / ram.get('total', 1)) if ram.get('total', 1) > 1 else 0 - global previous_oom # pylint: disable=global-statement + gpu, used_gpu, ram, _used_ram, oom = get_stats() threshold = 0 if (cmd_opts.lowvram and not cmd_opts.use_zluda) else opts.torch_gc_threshold collected = 0 - if force or threshold == 0 or used_gpu >= threshold or used_ram >= threshold: + if reason is None and force: + reason='force' + if threshold == 0 or used_gpu >= threshold: force = True + if reason is None: + reason = 'threshold' if oom > previous_oom: previous_oom = oom - log.warning(f'Torch GPU out-of-memory error: {mem}') + log.warning(f'Torch GPU out-of-memory error: {memstats.memory_stats()}') force = True + if reason is None: + reason = 'oom' if force: # actual gc collected = gc.collect() if not fast else 0 # python gc if cuda_ok: try: with torch.cuda.device(get_cuda_device_string()): + torch.cuda.synchronize() torch.cuda.empty_cache() # cuda gc torch.cuda.ipc_collect() except Exception: pass + else: + return gpu, ram t1 = time.time() - if 'gc' not in timer.process.records: - timer.process.records['gc'] = 0 - timer.process.records['gc'] += t1 - t0 - if not force or collected == 0: - return - mem = memstats.memory_stats() - saved = round(gpu.get('used', 0) - mem.get('gpu', {}).get('used', 0), 2) - before = { 'gpu': gpu.get('used', 0), 'ram': ram.get('used', 0) } - after = { 'gpu': mem.get('gpu', {}).get('used', 0), 'ram': mem.get('ram', {}).get('used', 0), 'retries': mem.get('retries', 0), 'oom': mem.get('oom', 0) } - utilization = { 'gpu': used_gpu, 'ram': used_ram, 'threshold': threshold } - results = { 'collected': collected, 'saved': saved } + timer.process.add('gc', t1 - t0) + if fast: + return gpu, ram + + new_gpu, new_used_gpu, new_ram, new_used_ram, oom = get_stats() + before = { 'gpu': gpu, 'ram': ram } + after = { 'gpu': new_gpu, 'ram': new_ram, 'oom': oom } + utilization = { 'gpu': new_used_gpu, 'ram': 
new_used_ram } + results = { 'gpu': round(gpu - new_gpu, 2), 'py': collected } fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access - log.debug(f'GC: utilization={utilization} gc={results} before={before} after={after} device={torch.device(get_optimal_device_name())} fn={fn} time={round(t1 - t0, 2)}') # pylint: disable=protected-access + log.debug(f'GC: current={after} prev={before} load={utilization} gc={results} fn={fn} why={reason} time={t1-t0:.2f}') + return new_gpu, new_ram def set_cuda_sync_mode(mode): @@ -471,7 +486,7 @@ def set_cuda_params(): device_name = get_raw_openvino_device() else: device_name = torch.device(get_optimal_device_name()) - log.info(f'Torch parameters: backend={backend} device={device_name} config={opts.cuda_dtype} dtype={dtype} vae={dtype_vae} unet={dtype_unet} context={inference_context.__name__} nohalf={opts.no_half} nohalfvae={opts.no_half_vae} upscast={opts.upcast_sampling} deterministic={opts.cudnn_deterministic} test-fp16={fp16_ok} test-bf16={bf16_ok} optimization="{opts.cross_attention_optimization}"') + log.info(f'Torch parameters: backend={backend} device={device_name} config={opts.cuda_dtype} dtype={dtype} vae={dtype_vae} unet={dtype_unet} context={inference_context.__name__} nohalf={opts.no_half} nohalfvae={opts.no_half_vae} upcast={opts.upcast_sampling} deterministic={opts.cudnn_deterministic} test-fp16={fp16_ok} test-bf16={bf16_ok} optimization="{opts.cross_attention_optimization}"') def cond_cast_unet(tensor): diff --git a/modules/devices_mac.py b/modules/devices_mac.py index fe7c80f31..2ddc1fc22 100644 --- a/modules/devices_mac.py +++ b/modules/devices_mac.py @@ -24,7 +24,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs): # pylint: disable=redefined output_dtype = kwargs.get('dtype', input.dtype) if output_dtype == torch.int64: return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) - elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16): + elif output_dtype == torch.bool or (cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16)): return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64) return cumsum_func(input, *args, **kwargs) @@ -42,7 +42,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs): # pylint: disable=redefined # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), - lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')) + lambda _, self, *args, **kwargs: self.device.type != 'mps' and ((args and isinstance(args[0], torch.device) and args[0].type == 'mps') or (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))) # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') diff --git a/modules/errors.py b/modules/errors.py index 527884cf1..6302057d7 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -36,7 +36,7 @@ def print_error_explanation(message): log.error(line) 
-def display(e: Exception, task, suppress=[]): +def display(e: Exception, task: str, suppress=[]): log.error(f"{task or 'error'}: {type(e).__name__}") console.print_exception(show_locals=False, max_frames=16, extra_lines=1, suppress=suppress, theme="ansi_dark", word_wrap=False, width=console.width) @@ -48,7 +48,7 @@ def display_once(e: Exception, task): already_displayed[task] = 1 -def run(code, task): +def run(code, task: str): try: code() except Exception as e: @@ -59,14 +59,14 @@ def exception(suppress=[]): console.print_exception(show_locals=False, max_frames=16, extra_lines=2, suppress=suppress, theme="ansi_dark", word_wrap=False, width=min([console.width, 200])) -def profile(profiler, msg: str, n: int = 5): +def profile(profiler, msg: str, n: int = 16): profiler.disable() import io import pstats stream = io.StringIO() # pylint: disable=abstract-class-instantiated p = pstats.Stats(profiler, stream=stream) p.sort_stats(pstats.SortKey.CUMULATIVE) - p.print_stats(100) + p.print_stats(200) # p.print_title() # p.print_call_heading(10, 'time') # p.print_callees(10) @@ -81,6 +81,7 @@ def profile(profiler, msg: str, n: int = 5): and '_lsprof' not in x and '/profiler' not in x and 'rich' not in x + and 'profile_torch' not in x and x.strip() != '' ] txt = '\n'.join(lines[:min(n, len(lines))]) diff --git a/modules/extensions.py b/modules/extensions.py index 5a8a53d29..ccd92dbf0 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -154,4 +154,4 @@ def list_extensions(): for dirname, path, is_builtin in extension_paths: extension = Extension(name=dirname, path=path, enabled=dirname not in disabled_extensions, is_builtin=is_builtin) extensions.append(extension) - shared.log.info(f'Disabled extensions: {[e.name for e in extensions if not e.enabled]}') + shared.log.debug(f'Disabled extensions: {[e.name for e in extensions if not e.enabled]}') diff --git a/modules/extra_networks.py b/modules/extra_networks.py index b464bd349..fe141cca1 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -1,6 +1,7 @@ import re +import inspect from collections import defaultdict -from modules import errors, shared, devices +from modules import errors, shared extra_network_registry = {} @@ -15,10 +16,14 @@ def register_extra_network(extra_network): def register_default_extra_networks(): - from modules.ui_extra_networks_hypernet import ExtraNetworkHypernet - register_extra_network(ExtraNetworkHypernet()) from modules.ui_extra_networks_styles import ExtraNetworkStyles register_extra_network(ExtraNetworkStyles()) + if shared.native: + from modules.lora.networks import extra_network_lora + register_extra_network(extra_network_lora) + if shared.opts.hypernetwork_enabled: + from modules.ui_extra_networks_hypernet import ExtraNetworkHypernet + register_extra_network(ExtraNetworkHypernet()) class ExtraNetworkParams: @@ -70,9 +75,12 @@ def is_stepwise(en_obj): return any([len(str(x).split("@")) > 1 for x in all_args]) # noqa C419 # pylint: disable=use-a-generator -def activate(p, extra_network_data, step=0): +def activate(p, extra_network_data=None, step=0, include=[], exclude=[]): """call activate for extra networks in extra_network_data in specified order, then call activate for all remaining registered networks with an empty argument list""" - if extra_network_data is None: + if p.disable_extra_networks: + return + extra_network_data = extra_network_data or p.network_data + if extra_network_data is None or len(extra_network_data) == 0: return stepwise = False for extra_network_args in 
extra_network_data.values(): @@ -82,35 +90,42 @@ def activate(p, extra_network_data, step=0): shared.log.warning("Composable LoRA not compatible with 'lora_force_diffusers'") stepwise = False shared.opts.data['lora_functional'] = stepwise or functional - with devices.autocast(): - for extra_network_name, extra_network_args in extra_network_data.items(): - extra_network = extra_network_registry.get(extra_network_name, None) - if extra_network is None: - errors.log.warning(f"Skipping unknown extra network: {extra_network_name}") - continue - try: + + for extra_network_name, extra_network_args in extra_network_data.items(): + extra_network = extra_network_registry.get(extra_network_name, None) + if extra_network is None: + errors.log.warning(f"Skipping unknown extra network: {extra_network_name}") + continue + try: + signature = list(inspect.signature(extra_network.activate).parameters) + if 'include' in signature and 'exclude' in signature: + extra_network.activate(p, extra_network_args, step=step, include=include, exclude=exclude) + else: extra_network.activate(p, extra_network_args, step=step) - except Exception as e: - errors.display(e, f"Activating network: type={extra_network_name} args:{extra_network_args}") - - for extra_network_name, extra_network in extra_network_registry.items(): - args = extra_network_data.get(extra_network_name, None) - if args is not None: - continue - try: - extra_network.activate(p, []) - except Exception as e: - errors.display(e, f"Activating network: type={extra_network_name}") - - p.extra_network_data = extra_network_data + except Exception as e: + errors.display(e, f"Activating network: type={extra_network_name} args:{extra_network_args}") + + for extra_network_name, extra_network in extra_network_registry.items(): + args = extra_network_data.get(extra_network_name, None) + if args is not None: + continue + try: + extra_network.activate(p, []) + except Exception as e: + errors.display(e, f"Activating network: type={extra_network_name}") + + p.network_data = extra_network_data if stepwise: p.stepwise_lora = True shared.opts.data['lora_functional'] = functional -def deactivate(p, extra_network_data): +def deactivate(p, extra_network_data=None): """call deactivate for extra networks in extra_network_data in specified order, then call deactivate for all remaining registered networks""" - if extra_network_data is None: + if p.disable_extra_networks: + return + extra_network_data = extra_network_data or p.network_data + if extra_network_data is None or len(extra_network_data) == 0: return for extra_network_name in extra_network_data: extra_network = extra_network_registry.get(extra_network_name, None) diff --git a/modules/face/faceid.py b/modules/face/faceid.py index b74e15dc5..4a4f07531 100644 --- a/modules/face/faceid.py +++ b/modules/face/faceid.py @@ -204,7 +204,6 @@ def face_id( ip_model_dict["face_image"] = face_images ip_model_dict["faceid_embeds"] = face_embeds # overwrite placeholder faceid_model.set_scale(scale) - extra_network_data = None if p.all_prompts is None or len(p.all_prompts) == 0: processing.process_init(p) @@ -215,11 +214,9 @@ def face_id( p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n+1) * p.batch_size] p.seeds = p.all_seeds[n * p.batch_size:(n+1) * p.batch_size] p.subseeds = p.all_subseeds[n * p.batch_size:(n+1) * p.batch_size] - p.prompts, extra_network_data = extra_networks.parse_prompts(p.prompts) + p.prompts, p.network_data = extra_networks.parse_prompts(p.prompts) - if not p.disable_extra_networks: - with 
devices.autocast(): - extra_networks.activate(p, extra_network_data) + extra_networks.activate(p, p.network_data) ip_model_dict.update({ "prompt": p.prompts[0], "negative_prompt": p.negative_prompts[0], @@ -239,8 +236,7 @@ def face_id( devices.torch_gc() ipadapter.unapply(p.sd_model) - if not p.disable_extra_networks: - extra_networks.deactivate(p, extra_network_data) + extra_networks.deactivate(p, p.network_data) p.extra_generation_params["IP Adapter"] = f"{basename}:{scale}" finally: diff --git a/modules/face/instantid_model.py b/modules/face/instantid_model.py index 51a4d7850..543b39ded 100644 --- a/modules/face/instantid_model.py +++ b/modules/face/instantid_model.py @@ -344,7 +344,6 @@ def __call__( return hidden_states def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): - # TODO attention_mask query = query.contiguous() key = key.contiguous() value = value.contiguous() diff --git a/modules/face/photomaker_model.py b/modules/face/photomaker_model.py index b62fe73b8..3595c6a36 100644 --- a/modules/face/photomaker_model.py +++ b/modules/face/photomaker_model.py @@ -244,7 +244,7 @@ def encode_prompt_with_trigger_word( prompt_embeds_list = [] prompts = [prompt, prompt_2] for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - input_ids = tokenizer.encode(prompt) # TODO: batch encode + input_ids = tokenizer.encode(prompt) clean_index = 0 clean_input_ids = [] class_token_index = [] @@ -296,7 +296,7 @@ def encode_prompt_with_trigger_word( prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) - class_tokens_mask = class_tokens_mask.to(device=device) # TODO: ignoring two-prompt case + class_tokens_mask = class_tokens_mask.to(device=device) return prompt_embeds, pooled_prompt_embeds, class_tokens_mask @@ -332,7 +332,7 @@ def __call__( callback_steps: int = 1, # Added parameters (for PhotoMaker) input_id_images: PipelineImageInput = None, - start_merge_step: int = 0, # TODO: change to `style_strength_ratio` in the future + start_merge_step: int = 0, class_tokens_mask: Optional[torch.LongTensor] = None, prompt_embeds_text_only: Optional[torch.FloatTensor] = None, pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None, @@ -410,7 +410,7 @@ def __call__( ( prompt_embeds_text_only, negative_prompt_embeds, - pooled_prompt_embeds_text_only, # TODO: replace the pooled_prompt_embeds with text only prompt + pooled_prompt_embeds_text_only, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt_text_only, @@ -431,7 +431,7 @@ def __call__( if not isinstance(input_id_images[0], torch.Tensor): id_pixel_values = self.id_image_processor(input_id_images, return_tensors="pt").pixel_values - id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype) # TODO: multiple prompts + id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype) # 6. Get the update text embedding with the stacked ID embedding prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask) diff --git a/modules/files_cache.py b/modules/files_cache.py index fa2241afc..d65e0f4f4 100644 --- a/modules/files_cache.py +++ b/modules/files_cache.py @@ -6,6 +6,7 @@ from installer import log +do_cache_folders = os.environ.get('SD_NO_CACHE', None) is None class Directory: # forward declaration ... 
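A minimal sketch of how the new `SD_NO_CACHE` gate above is meant to behave, simplified from the updated `get_directory` flow later in this file's diff; the helper name `get_directory_sketch` and the launch command in the comment are illustrative assumptions, not part of the patch:

import os

# Folder caching stays enabled unless the SD_NO_CACHE environment variable is set,
# e.g. launching with `SD_NO_CACHE=1` (hypothetical invocation) disables it.
do_cache_folders = os.environ.get('SD_NO_CACHE', None) is None

def get_directory_sketch(path, cache_folders, fetch_directory):
    # Simplified view of the patched logic: fetched directories are only memoized
    # when caching is enabled; otherwise every call re-scans the filesystem.
    if path not in cache_folders:
        directory = fetch_directory(directory_path=path)
        if directory and do_cache_folders:
            cache_folders[path] = directory
        return directory
    return cache_folders[path]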
@@ -87,8 +88,6 @@ def is_stale(self) -> bool: return not self.is_directory or self.mtime != self.live_mtime - - class DirectoryCache(UserDict, DirectoryCollection): def __delattr__(self, directory_path: str) -> None: directory: Directory = get_directory(directory_path, fetch=False) @@ -126,7 +125,7 @@ def clean_directory(directory: Directory, /, recursive: RecursiveType=False) -> return is_clean -def get_directory(directory_or_path: str, /, fetch:bool=True) -> Union[Directory, None]: +def get_directory(directory_or_path: str, /, fetch: bool=True) -> Union[Directory, None]: if isinstance(directory_or_path, Directory): if directory_or_path.is_directory: return directory_or_path @@ -136,8 +135,9 @@ def get_directory(directory_or_path: str, /, fetch:bool=True) -> Union[Directory if not cache_folders.get(directory_or_path, None): if fetch: directory = fetch_directory(directory_path=directory_or_path) - if directory: + if directory and do_cache_folders: cache_folders[directory_or_path] = directory + return directory else: clean_directory(cache_folders[directory_or_path]) return cache_folders[directory_or_path] if directory_or_path in cache_folders else None diff --git a/modules/freescale/__init__.py b/modules/freescale/__init__.py new file mode 100644 index 000000000..7b9c17f5d --- /dev/null +++ b/modules/freescale/__init__.py @@ -0,0 +1,4 @@ +# Credits: https://github.com/ali-vilab/FreeScale + +from .freescale_pipeline import StableDiffusionXLFreeScale +from .freescale_pipeline_img2img import StableDiffusionXLFreeScaleImg2Img diff --git a/modules/freescale/free_lunch_utils.py b/modules/freescale/free_lunch_utils.py new file mode 100644 index 000000000..be26b732a --- /dev/null +++ b/modules/freescale/free_lunch_utils.py @@ -0,0 +1,305 @@ +from typing import Any, Dict, Optional, Tuple +import torch +import torch.fft as fft +from diffusers.utils import is_torch_version + +""" Borrowed from https://github.com/ChenyangSi/FreeU/blob/main/demo/free_lunch_utils.py +""" + +def isinstance_str(x: object, cls_name: str): + """ + Checks whether x has any class *named* cls_name in its ancestry. + Doesn't require access to the class's implementation. + + Useful for patching! 
+ """ + + for _cls in x.__class__.__mro__: + if _cls.__name__ == cls_name: + return True + + return False + + +def Fourier_filter(x, threshold, scale): + dtype = x.dtype + x = x.type(torch.float32) + # FFT + x_freq = fft.fftn(x, dim=(-2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-2, -1)) + + B, C, H, W = x_freq.shape + mask = torch.ones((B, C, H, W)).cuda() + + crow, ccol = H // 2, W //2 + mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale + x_freq = x_freq * mask + + # IFFT + x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real + + x_filtered = x_filtered.type(dtype) + return x_filtered + + +def register_upblock2d(model): + def up_forward(self): + def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + #print(f"in upblock2d, hidden states shape: {hidden_states.shape}") + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, "UpBlock2D"): + upsample_block.forward = up_forward(upsample_block) + + +def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2): + def up_forward(self): + def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}") + + # --------------- FreeU code ----------------------- + # Only operate on the first two stages + if hidden_states.shape[1] == 1280: + hidden_states[:,:640] = hidden_states[:,:640] * self.b1 + res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1) + if hidden_states.shape[1] == 640: + hidden_states[:,:320] = hidden_states[:,:320] * self.b2 + res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2) + # --------------------------------------------------------- + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + 
for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, "UpBlock2D"): + upsample_block.forward = up_forward(upsample_block) + setattr(upsample_block, 'b1', b1) + setattr(upsample_block, 'b2', b2) + setattr(upsample_block, 's1', s1) + setattr(upsample_block, 's2', s2) + + +def register_crossattn_upblock2d(model): + def up_forward(self): + def forward( + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + #print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}") + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, "CrossAttnUpBlock2D"): + upsample_block.forward = up_forward(upsample_block) + + +def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2): + def up_forward(self): + def forward( + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + #print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}") + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = 
res_hidden_states_tuple[:-1] + + # --------------- FreeU code ----------------------- + # Only operate on the first two stages + if hidden_states.shape[1] == 1280: + hidden_states[:,:640] = hidden_states[:,:640] * self.b1 + res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1) + if hidden_states.shape[1] == 640: + hidden_states[:,:320] = hidden_states[:,:320] * self.b2 + res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2) + # --------------------------------------------------------- + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + # hidden_states = attn( + # hidden_states, + # encoder_hidden_states=encoder_hidden_states, + # cross_attention_kwargs=cross_attention_kwargs, + # encoder_attention_mask=encoder_attention_mask, + # return_dict=False, + # )[0] + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + return forward + + for i, upsample_block in enumerate(model.unet.up_blocks): + if isinstance_str(upsample_block, "CrossAttnUpBlock2D"): + upsample_block.forward = up_forward(upsample_block) + setattr(upsample_block, 'b1', b1) + setattr(upsample_block, 'b2', b2) + setattr(upsample_block, 's1', s1) + setattr(upsample_block, 's2', s2) diff --git a/modules/freescale/freescale_pipeline.py b/modules/freescale/freescale_pipeline.py new file mode 100644 index 000000000..9b7a68b68 --- /dev/null +++ b/modules/freescale/freescale_pipeline.py @@ -0,0 +1,1189 @@ +from inspect import isfunction +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import inspect +import os +import random + +import numpy as np +import torch +import torch.nn.functional as F +from einops import rearrange +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.attention_processor import AttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils.torch_utils import randn_tensor +from diffusers.utils import is_accelerate_available, is_accelerate_version, logging, replace_example_docstring +from diffusers.pipelines.pipeline_utils import 
DiffusionPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from diffusers.models.attention import BasicTransformerBlock + +from .scale_attention import ori_forward, scale_forward + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLPipeline + + >>> pipe = StableDiffusionXLPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + +def exists(val): + return val is not None + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + +to_torch = partial(torch.tensor, dtype=torch.float16) +betas = make_beta_schedule("linear", 1000, linear_start=0.00085, linear_end=0.012) +alphas = 1. - betas +alphas_cumprod = np.cumprod(alphas, axis=0) +sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod)) +sqrt_one_minus_alphas_cumprod = to_torch(np.sqrt(1. 
- alphas_cumprod)) + +def q_sample(x_start, t, init_noise_sigma = 1.0, noise=None, device=None): + noise = default(noise, lambda: torch.randn_like(x_start)).to(device) * init_noise_sigma + return (extract_into_tensor(sqrt_alphas_cumprod.to(device), t, x_start.shape) * x_start + + extract_into_tensor(sqrt_one_minus_alphas_cumprod.to(device), t, x_start.shape) * noise) + +def get_views(height, width, h_window_size=128, w_window_size=128, h_window_stride=64, w_window_stride=64, vae_scale_factor=8): + height //= vae_scale_factor + width //= vae_scale_factor + num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1 + num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1 + total_num_blocks = int(num_blocks_height * num_blocks_width) + views = [] + for i in range(total_num_blocks): + h_start = int((i // num_blocks_width) * h_window_stride) + h_end = h_start + h_window_size + w_start = int((i % num_blocks_width) * w_window_stride) + w_end = w_start + w_window_size + + if h_end > height: + h_start = int(h_start + height - h_end) + h_end = int(height) + if w_end > width: + w_start = int(w_start + width - w_end) + w_end = int(width) + if h_start < 0: + h_end = int(h_end - h_start) + h_start = 0 + if w_start < 0: + w_end = int(w_end - w_start) + w_start = 0 + + random_jitter = True + if random_jitter: + h_jitter_range = (h_window_size - h_window_stride) // 4 + w_jitter_range = (w_window_size - w_window_stride) // 4 + h_jitter = 0 + w_jitter = 0 + + if (w_start != 0) and (w_end != width): + w_jitter = random.randint(-w_jitter_range, w_jitter_range) + elif (w_start == 0) and (w_end != width): + w_jitter = random.randint(-w_jitter_range, 0) + elif (w_start != 0) and (w_end == width): + w_jitter = random.randint(0, w_jitter_range) + if (h_start != 0) and (h_end != height): + h_jitter = random.randint(-h_jitter_range, h_jitter_range) + elif (h_start == 0) and (h_end != height): + h_jitter = random.randint(-h_jitter_range, 0) + elif (h_start != 0) and (h_end == height): + h_jitter = random.randint(0, h_jitter_range) + h_start += (h_jitter + h_jitter_range) + h_end += (h_jitter + h_jitter_range) + w_start += (w_jitter + w_jitter_range) + w_end += (w_jitter + w_jitter_range) + + views.append((h_start, h_end, w_start, w_end)) + return views + +def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3): + x_coord = torch.arange(kernel_size) + gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2)) + gaussian_1d = gaussian_1d / gaussian_1d.sum() + gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :] + kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1) + + return kernel + +def gaussian_filter(latents, kernel_size=3, sigma=1.0): + channels = latents.shape[1] + kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype) + if len(latents.shape) == 5: + b = latents.shape[0] + latents = rearrange(latents, 'b c t i j -> (b t) c i j') + blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels) + blurred_latents = rearrange(blurred_latents, '(b t) c i j -> b c t i j', b=b) + else: + blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels) + + return blurred_latents + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to 
`guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLFreeScale(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
+ """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = self.unet.config.sample_size + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
+ """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + model_sequence = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + model_sequence.extend([self.unet, self.vae]) + + hook = None + for cpu_offloaded_model in model_sequence: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. 
+ lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt, negative_prompt_2] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
+ ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + resolutions_list: Optional[Union[int, List[int]]] = None, + restart_steps: Optional[Union[int, List[int]]] = None, + cosine_scale: float = 2.0, + dilate_tau: int = 35, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. 
The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + # 0. Default height and width to unet + if resolutions_list: + height, width = resolutions_list[0] + target_sizes = resolutions_list[1:] + if not restart_steps: + restart_steps = [15] * len(target_sizes) + else: + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 9.1 Apply denoising_end + if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + results_list = [] + + for block in self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks: + for module in block.modules(): + if isinstance(module, BasicTransformerBlock): + module.forward = ori_forward.__get__(module, BasicTransformerBlock) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + results_list.append(latents) + + for restart_index, target_size in enumerate(target_sizes): + restart_step = restart_steps[restart_index] + target_size_ = [target_size[0]//8, target_size[1]//8] + + for block in self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks: + for module in block.modules(): + if isinstance(module, BasicTransformerBlock): + module.forward = scale_forward.__get__(module, BasicTransformerBlock) + module.current_hw = target_size + + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + latents = latents / self.vae.config.scaling_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = torch.nn.functional.interpolate( + image, + size=target_size, + mode='bicubic', + ) + latents = self.vae.encode(image).latent_dist.sample().to(self.vae.dtype) + latents = latents * self.vae.config.scaling_factor + + noise_latents = [] + noise = torch.randn_like(latents) + for timestep in self.scheduler.timesteps: + noise_latent = self.scheduler.add_noise(latents, noise, timestep.unsqueeze(0)) + noise_latents.append(noise_latent) + latents = 
noise_latents[restart_step] + + self.scheduler._step_index = 0 + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + + if i < restart_step: + self.scheduler._step_index += 1 + progress_bar.update() + continue + + cosine_factor = 0.5 * (1 + torch.cos(torch.pi * (self.scheduler.config.num_train_timesteps - t) / self.scheduler.config.num_train_timesteps)).cpu() + c1 = cosine_factor ** cosine_scale + latents = latents * (1 - c1) + noise_latents[i] * c1 + + dilate_coef=target_size[1]//1024 + + dilate_layers = [ + # "down_blocks.1.resnets.0.conv1", + # "down_blocks.1.resnets.0.conv2", + # "down_blocks.1.resnets.1.conv1", + # "down_blocks.1.resnets.1.conv2", + "down_blocks.1.downsamplers.0.conv", + "down_blocks.2.resnets.0.conv1", + "down_blocks.2.resnets.0.conv2", + "down_blocks.2.resnets.1.conv1", + "down_blocks.2.resnets.1.conv2", + # "up_blocks.0.resnets.0.conv1", + # "up_blocks.0.resnets.0.conv2", + # "up_blocks.0.resnets.1.conv1", + # "up_blocks.0.resnets.1.conv2", + # "up_blocks.0.resnets.2.conv1", + # "up_blocks.0.resnets.2.conv2", + # "up_blocks.0.upsamplers.0.conv", + # "up_blocks.1.resnets.0.conv1", + # "up_blocks.1.resnets.0.conv2", + # "up_blocks.1.resnets.1.conv1", + # "up_blocks.1.resnets.1.conv2", + # "up_blocks.1.resnets.2.conv1", + # "up_blocks.1.resnets.2.conv2", + # "up_blocks.1.upsamplers.0.conv", + # "up_blocks.2.resnets.0.conv1", + # "up_blocks.2.resnets.0.conv2", + # "up_blocks.2.resnets.1.conv1", + # "up_blocks.2.resnets.1.conv2", + # "up_blocks.2.resnets.2.conv1", + # "up_blocks.2.resnets.2.conv2", + "mid_block.resnets.0.conv1", + "mid_block.resnets.0.conv2", + "mid_block.resnets.1.conv1", + "mid_block.resnets.1.conv2" + ] + + for name, module in self.unet.named_modules(): + if name in dilate_layers: + if i < dilate_tau: + module.dilation = (dilate_coef, dilate_coef) + module.padding = (dilate_coef, dilate_coef) + else: + module.dilation = (1, 1) + module.padding = (1, 1) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + for name, module in self.unet.named_modules(): + # if ('.conv' in name) and ('.conv_' not in name): + if name in dilate_layers: + module.dilation = (1, 1) + module.padding = (1, 1) + + results_list.append(latents) + + """ + final_results = [] + for latents in results_list: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + final_results += [(image,)] + else: + final_results += [StableDiffusionXLPipelineOutput(images=image)] + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + return final_results + """ + return StableDiffusionXLPipelineOutput(images=results_list) + + # Overrride to properly handle the loading and unloading of the additional text encoder. + def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + # We could have accessed the unet config from `lora_state_dict()` too. We pass + # it here explicitly to be able to tell that it's coming from an SDXL + # pipeline. + state_dict, network_alphas = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, + unet_config=self.unet.config, + **kwargs, + ) + self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) + + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + ) + + text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." 
in k} + if len(text_encoder_2_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_2_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder_2, + prefix="text_encoder_2", + lora_scale=self.lora_scale, + ) + + @classmethod + def save_lora_weights( + self, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + state_dict = {} + + def pack_weights(layers, prefix): + layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers + layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} + return layers_state_dict + + state_dict.update(pack_weights(unet_lora_layers, "unet")) + + if text_encoder_lora_layers and text_encoder_2_lora_layers: + state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) + state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + + self.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def _remove_text_encoder_monkey_patch(self): + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) diff --git a/modules/freescale/freescale_pipeline_img2img.py b/modules/freescale/freescale_pipeline_img2img.py new file mode 100644 index 000000000..df4c3f0c1 --- /dev/null +++ b/modules/freescale/freescale_pipeline_img2img.py @@ -0,0 +1,1245 @@ +from inspect import isfunction +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import inspect +import os +import random + +from PIL import Image +import numpy as np +import torch +import torch.nn.functional as F +from einops import rearrange +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +import torchvision.transforms as transforms + +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.attention_processor import AttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils.torch_utils import randn_tensor +from diffusers.utils import is_accelerate_available, is_accelerate_version, logging, replace_example_docstring +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from diffusers.models.attention import BasicTransformerBlock + +from .scale_attention import ori_forward, scale_forward + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLPipeline + + >>> pipe = StableDiffusionXLPipeline.from_pretrained( + ... 
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + +def process_image_to_tensor(image): + image = image.convert("RGB") + # image = Image.open(image_path).convert("RGB") + transform = transforms.Compose( + [ + # transforms.Resize((1024, 1024)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ] + ) + image_tensor = transform(image) + return image_tensor + +def process_image_to_bitensor(image): + # image = Image.open(image_path).convert("L") + image = image.convert("L") + transform = transforms.ToTensor() + image_tensor = transform(image) + binary_tensor = torch.where(image_tensor != 0, torch.tensor(1.0), torch.tensor(0.0)) + return binary_tensor + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + +def exists(val): + return val is not None + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + +to_torch = partial(torch.tensor, dtype=torch.float16) +betas = make_beta_schedule("linear", 1000, linear_start=0.00085, linear_end=0.012) +alphas = 1. - betas +alphas_cumprod = np.cumprod(alphas, axis=0) +sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod)) +sqrt_one_minus_alphas_cumprod = to_torch(np.sqrt(1. 
- alphas_cumprod)) + +def q_sample(x_start, t, init_noise_sigma = 1.0, noise=None, device=None): + noise = default(noise, lambda: torch.randn_like(x_start)).to(device) * init_noise_sigma + return (extract_into_tensor(sqrt_alphas_cumprod.to(device), t, x_start.shape) * x_start + + extract_into_tensor(sqrt_one_minus_alphas_cumprod.to(device), t, x_start.shape) * noise) + +def get_views(height, width, h_window_size=128, w_window_size=128, h_window_stride=64, w_window_stride=64, vae_scale_factor=8): + height //= vae_scale_factor + width //= vae_scale_factor + num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1 + num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1 + total_num_blocks = int(num_blocks_height * num_blocks_width) + views = [] + for i in range(total_num_blocks): + h_start = int((i // num_blocks_width) * h_window_stride) + h_end = h_start + h_window_size + w_start = int((i % num_blocks_width) * w_window_stride) + w_end = w_start + w_window_size + + if h_end > height: + h_start = int(h_start + height - h_end) + h_end = int(height) + if w_end > width: + w_start = int(w_start + width - w_end) + w_end = int(width) + if h_start < 0: + h_end = int(h_end - h_start) + h_start = 0 + if w_start < 0: + w_end = int(w_end - w_start) + w_start = 0 + + random_jitter = True + if random_jitter: + h_jitter_range = (h_window_size - h_window_stride) // 4 + w_jitter_range = (w_window_size - w_window_stride) // 4 + h_jitter = 0 + w_jitter = 0 + + if (w_start != 0) and (w_end != width): + w_jitter = random.randint(-w_jitter_range, w_jitter_range) + elif (w_start == 0) and (w_end != width): + w_jitter = random.randint(-w_jitter_range, 0) + elif (w_start != 0) and (w_end == width): + w_jitter = random.randint(0, w_jitter_range) + if (h_start != 0) and (h_end != height): + h_jitter = random.randint(-h_jitter_range, h_jitter_range) + elif (h_start == 0) and (h_end != height): + h_jitter = random.randint(-h_jitter_range, 0) + elif (h_start != 0) and (h_end == height): + h_jitter = random.randint(0, h_jitter_range) + h_start += (h_jitter + h_jitter_range) + h_end += (h_jitter + h_jitter_range) + w_start += (w_jitter + w_jitter_range) + w_end += (w_jitter + w_jitter_range) + + views.append((h_start, h_end, w_start, w_end)) + return views + +def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3): + x_coord = torch.arange(kernel_size) + gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2)) + gaussian_1d = gaussian_1d / gaussian_1d.sum() + gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :] + kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1) + + return kernel + +def gaussian_filter(latents, kernel_size=3, sigma=1.0): + channels = latents.shape[1] + kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype) + if len(latents.shape) == 5: + b = latents.shape[0] + latents = rearrange(latents, 'b c t i j -> (b t) c i j') + blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels) + blurred_latents = rearrange(blurred_latents, '(b t) c i j -> b c t i j', b=b) + else: + blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels) + + return blurred_latents + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to 
`guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLFreeScaleImg2Img(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
+ """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = self.unet.config.sample_size + + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
+ """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + model_sequence = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + model_sequence.extend([self.unet, self.vae]) + + hook = None + for cpu_offloaded_model in model_sequence: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. 
+ lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt, negative_prompt_2] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
+ ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + resolutions_list: Optional[Union[int, List[int]]] = None, + restart_steps: Optional[Union[int, List[int]]] = None, + cosine_scale: float = 2.0, + cosine_scale_bg: float = 1.0, + dilate_tau: int = 35, + img_path: Optional[str] = "", + mask_path: Optional[str] = "", + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. 
As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + + # 0. Default height and width to unet + if resolutions_list: + height, width = resolutions_list[0] + target_sizes = resolutions_list[1:] + if not restart_steps: + restart_steps = [15] * len(target_sizes) + else: + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 9.1 Apply denoising_end + if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + results_list = [] + + for block in self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks: + for module in block.modules(): + if isinstance(module, BasicTransformerBlock): + module.forward = ori_forward.__get__(module, BasicTransformerBlock) + + if img_path != '': + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + input_image = process_image_to_tensor(img_path).unsqueeze(0).to(dtype=self.vae.dtype, device=device) + latents = self.vae.encode(input_image).latent_dist.sample().to(self.vae.dtype) + latents = latents * self.vae.config.scaling_factor + else: + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + results_list.append(latents) + + if mask_path != '': + mask = process_image_to_bitensor(mask_path).unsqueeze(0) + + for restart_index, target_size in enumerate(target_sizes): + restart_step = restart_steps[restart_index] + target_size_ = [target_size[0]//8, target_size[1]//8] + + for block in self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks: + for module in block.modules(): + if isinstance(module, BasicTransformerBlock): + module.forward = scale_forward.__get__(module, BasicTransformerBlock) + module.current_hw = target_size + + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + latents = latents / self.vae.config.scaling_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = torch.nn.functional.interpolate( + image, + size=target_size, + mode='bicubic', + ) + latents = self.vae.encode(image).latent_dist.sample().to(self.vae.dtype) + latents = latents * self.vae.config.scaling_factor + + if mask_path != '': + mask_ = torch.nn.functional.interpolate( + mask, + size=target_size_, + mode="nearest", + ).to(device) + + noise_latents = [] + noise = torch.randn_like(latents) + for timestep in self.scheduler.timesteps: + noise_latent = self.scheduler.add_noise(latents, noise, timestep.unsqueeze(0)) + noise_latents.append(noise_latent) + latents = noise_latents[restart_step] + + self.scheduler._step_index = 0 + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + + if i < restart_step: + self.scheduler._step_index += 1 + progress_bar.update() + continue + + cosine_factor = 0.5 * (1 + torch.cos(torch.pi * (self.scheduler.config.num_train_timesteps - t) / self.scheduler.config.num_train_timesteps)).cpu() + if mask_path != '': + c1 = (cosine_factor ** (mask_ * cosine_scale + (1-mask_) * cosine_scale_bg)).to(dtype=torch.float16) + else: + c1 = cosine_factor ** cosine_scale + latents = latents * (1 - c1) + noise_latents[i] * c1 + + dilate_coef=target_size[1]//1024 + + dilate_layers = [ + # "down_blocks.1.resnets.0.conv1", + # "down_blocks.1.resnets.0.conv2", + # "down_blocks.1.resnets.1.conv1", + # "down_blocks.1.resnets.1.conv2", + "down_blocks.1.downsamplers.0.conv", + "down_blocks.2.resnets.0.conv1", + "down_blocks.2.resnets.0.conv2", + "down_blocks.2.resnets.1.conv1", + "down_blocks.2.resnets.1.conv2", + # "up_blocks.0.resnets.0.conv1", + # "up_blocks.0.resnets.0.conv2", + # "up_blocks.0.resnets.1.conv1", + # "up_blocks.0.resnets.1.conv2", + # "up_blocks.0.resnets.2.conv1", + # "up_blocks.0.resnets.2.conv2", + # "up_blocks.0.upsamplers.0.conv", + # "up_blocks.1.resnets.0.conv1", + # "up_blocks.1.resnets.0.conv2", + # "up_blocks.1.resnets.1.conv1", + # "up_blocks.1.resnets.1.conv2", + # "up_blocks.1.resnets.2.conv1", + # "up_blocks.1.resnets.2.conv2", + # "up_blocks.1.upsamplers.0.conv", + # "up_blocks.2.resnets.0.conv1", + # 
"up_blocks.2.resnets.0.conv2", + # "up_blocks.2.resnets.1.conv1", + # "up_blocks.2.resnets.1.conv2", + # "up_blocks.2.resnets.2.conv1", + # "up_blocks.2.resnets.2.conv2", + "mid_block.resnets.0.conv1", + "mid_block.resnets.0.conv2", + "mid_block.resnets.1.conv1", + "mid_block.resnets.1.conv2" + ] + + for name, module in self.unet.named_modules(): + if name in dilate_layers: + if i < dilate_tau: + module.dilation = (dilate_coef, dilate_coef) + module.padding = (dilate_coef, dilate_coef) + else: + module.dilation = (1, 1) + module.padding = (1, 1) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + for name, module in self.unet.named_modules(): + # if ('.conv' in name) and ('.conv_' not in name): + if name in dilate_layers: + module.dilation = (1, 1) + module.padding = (1, 1) + + results_list.append(latents) + + """ + final_results = [] + for latents in results_list: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + final_results += [(image,)] + else: + final_results += [StableDiffusionXLPipelineOutput(images=image)] + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + return final_results + """ + return StableDiffusionXLPipelineOutput(images=results_list) + + # Overrride to properly handle the loading and unloading of the additional text encoder. + def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + # We could have accessed the unet config from `lora_state_dict()` too. We pass + # it here explicitly to be able to tell that it's coming from an SDXL + # pipeline. + state_dict, network_alphas = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, + unet_config=self.unet.config, + **kwargs, + ) + self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) + + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + ) + + text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." 
in k} + if len(text_encoder_2_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_2_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder_2, + prefix="text_encoder_2", + lora_scale=self.lora_scale, + ) + + @classmethod + def save_lora_weights( + self, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + state_dict = {} + + def pack_weights(layers, prefix): + layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers + layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} + return layers_state_dict + + state_dict.update(pack_weights(unet_lora_layers, "unet")) + + if text_encoder_lora_layers and text_encoder_2_lora_layers: + state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) + state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + + self.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def _remove_text_encoder_monkey_patch(self): + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) diff --git a/modules/freescale/scale_attention.py b/modules/freescale/scale_attention.py new file mode 100644 index 000000000..9e83d5067 --- /dev/null +++ b/modules/freescale/scale_attention.py @@ -0,0 +1,367 @@ +from typing import Any, Dict, Optional +import random +import torch +import torch.nn.functional as F +from einops import rearrange + + +def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3): + x_coord = torch.arange(kernel_size) + gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2)) + gaussian_1d = gaussian_1d / gaussian_1d.sum() + gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :] + kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1) + + return kernel + +def gaussian_filter(latents, kernel_size=3, sigma=1.0): + channels = latents.shape[1] + kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype) + blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels) + + return blurred_latents + +def get_views(height, width, h_window_size=128, w_window_size=128, scale_factor=8): + height = int(height) + width = int(width) + h_window_stride = h_window_size // 2 + w_window_stride = w_window_size // 2 + h_window_size = int(h_window_size / scale_factor) + w_window_size = int(w_window_size / scale_factor) + h_window_stride = int(h_window_stride / scale_factor) + w_window_stride = int(w_window_stride / scale_factor) + num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1 + num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1 + total_num_blocks = int(num_blocks_height * num_blocks_width) + views = [] + for i in range(total_num_blocks): + h_start = int((i // num_blocks_width) * h_window_stride) + h_end = h_start + 
h_window_size + w_start = int((i % num_blocks_width) * w_window_stride) + w_end = w_start + w_window_size + + if h_end > height: + h_start = int(h_start + height - h_end) + h_end = int(height) + if w_end > width: + w_start = int(w_start + width - w_end) + w_end = int(width) + if h_start < 0: + h_end = int(h_end - h_start) + h_start = 0 + if w_start < 0: + w_end = int(w_end - w_start) + w_start = 0 + + random_jitter = True + if random_jitter: + h_jitter_range = h_window_size // 8 + w_jitter_range = w_window_size // 8 + h_jitter = 0 + w_jitter = 0 + + if (w_start != 0) and (w_end != width): + w_jitter = random.randint(-w_jitter_range, w_jitter_range) + elif (w_start == 0) and (w_end != width): + w_jitter = random.randint(-w_jitter_range, 0) + elif (w_start != 0) and (w_end == width): + w_jitter = random.randint(0, w_jitter_range) + if (h_start != 0) and (h_end != height): + h_jitter = random.randint(-h_jitter_range, h_jitter_range) + elif (h_start == 0) and (h_end != height): + h_jitter = random.randint(-h_jitter_range, 0) + elif (h_start != 0) and (h_end == height): + h_jitter = random.randint(0, h_jitter_range) + h_start += (h_jitter + h_jitter_range) + h_end += (h_jitter + h_jitter_range) + w_start += (w_jitter + w_jitter_range) + w_end += (w_jitter + w_jitter_range) + + views.append((h_start, h_end, w_start, w_end)) + return views + +def scale_forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, +): + # Notice that normalization is always applied before the real computation in the following blocks. + if self.current_hw: + current_scale_num_h, current_scale_num_w = max(self.current_hw[0] // 1024, 1), max(self.current_hw[1] // 1024, 1) + else: + current_scale_num_h, current_scale_num_w = 1, 1 + + # 0. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + # 2. 
Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + ratio_hw = current_scale_num_h / current_scale_num_w + latent_h = int((norm_hidden_states.shape[1] * ratio_hw) ** 0.5) + latent_w = int(latent_h / ratio_hw) + scale_factor = 128 * current_scale_num_h / latent_h + if ratio_hw > 1: + sub_h = 128 + sub_w = int(128 / ratio_hw) + else: + sub_h = int(128 * ratio_hw) + sub_w = 128 + + h_jitter_range = int(sub_h / scale_factor // 8) + w_jitter_range = int(sub_w / scale_factor // 8) + views = get_views(latent_h, latent_w, sub_h, sub_w, scale_factor = scale_factor) + + current_scale_num = max(current_scale_num_h, current_scale_num_w) + global_views = [[h, w] for h in range(current_scale_num_h) for w in range(current_scale_num_w)] + + four_window = True + fourg_window = False + + if four_window: + norm_hidden_states_ = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h) + norm_hidden_states_ = F.pad(norm_hidden_states_, (0, 0, w_jitter_range, w_jitter_range, h_jitter_range, h_jitter_range), 'constant', 0) + value = torch.zeros_like(norm_hidden_states_) + count = torch.zeros_like(norm_hidden_states_) + for index, view in enumerate(views): + h_start, h_end, w_start, w_end = view + local_states = norm_hidden_states_[:, h_start:h_end, w_start:w_end, :] + local_states = rearrange(local_states, 'bh h w d -> bh (h w) d') + local_output = self.attn1( + local_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + local_output = rearrange(local_output, 'bh (h w) d -> bh h w d', h = int(sub_h / scale_factor)) + + value[:, h_start:h_end, w_start:w_end, :] += local_output * 1 + count[:, h_start:h_end, w_start:w_end, :] += 1 + + value = value[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :] + count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :] + attn_output = torch.where(count>0, value/count, value) + + gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0) + + attn_output_global = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h) + + gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0) + + attn_output = gaussian_local + (attn_output_global - gaussian_global) + attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d') + + elif fourg_window: + norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h) + norm_hidden_states_ = F.pad(norm_hidden_states, (0, 0, w_jitter_range, w_jitter_range, h_jitter_range, h_jitter_range), 'constant', 0) + value = torch.zeros_like(norm_hidden_states_) + count = torch.zeros_like(norm_hidden_states_) + for index, view in enumerate(views): + h_start, h_end, w_start, w_end = view + local_states = norm_hidden_states_[:, h_start:h_end, w_start:w_end, :] + local_states = rearrange(local_states, 'bh h w d -> bh (h w) d') + local_output = self.attn1( + local_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + local_output = 
rearrange(local_output, 'bh (h w) d -> bh h w d', h = int(sub_h / scale_factor)) + + value[:, h_start:h_end, w_start:w_end, :] += local_output * 1 + count[:, h_start:h_end, w_start:w_end, :] += 1 + + value = value[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :] + count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :] + attn_output = torch.where(count>0, value/count, value) + + gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0) + + value = torch.zeros_like(norm_hidden_states) + count = torch.zeros_like(norm_hidden_states) + for index, global_view in enumerate(global_views): + h, w = global_view + global_states = norm_hidden_states[:, h::current_scale_num_h, w::current_scale_num_w, :] + global_states = rearrange(global_states, 'bh h w d -> bh (h w) d') + global_output = self.attn1( + global_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + global_output = rearrange(global_output, 'bh (h w) d -> bh h w d', h = int(global_output.shape[1] ** 0.5)) + + value[:, h::current_scale_num_h, w::current_scale_num_w, :] += global_output * 1 + count[:, h::current_scale_num_h, w::current_scale_num_w, :] += 1 + + attn_output_global = torch.where(count>0, value/count, value) + + gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0) + + attn_output = gaussian_local + (attn_output_global - gaussian_global) + attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d') + + else: + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + # 2.5 ends + + # 3. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 
+ ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [ + self.ff(hid_slice) + for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim) + ], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + +def ori_forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, +): + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + # 2. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + # 2.5 ends + + # 3. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 
+ ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [ + self.ff(hid_slice) + for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim) + ], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states diff --git a/modules/ggml/__init__.py b/modules/ggml/__init__.py index acac6057c..44721d846 100644 --- a/modules/ggml/__init__.py +++ b/modules/ggml/__init__.py @@ -1,11 +1,35 @@ -from pathlib import Path +import os +import time import torch -import gguf -from .gguf_utils import TORCH_COMPATIBLE_QTYPES -from .gguf_tensor import GGMLTensor +import diffusers +import transformers -def load_gguf_state_dict(path: str, compute_dtype: torch.dtype) -> dict[str, GGMLTensor]: +def install_gguf(): + # pip install git+https://github.com/junejae/transformers@feature/t5-gguf + # https://github.com/ggerganov/llama.cpp/issues/9566 + from installer import install + install('gguf', quiet=True) + import importlib + import gguf + from modules import shared + scripts_dir = os.path.join(os.path.dirname(gguf.__file__), '..', 'scripts') + if os.path.exists(scripts_dir): + os.rename(scripts_dir, scripts_dir + str(time.time())) + # monkey patch transformers/diffusers so they detect newly installed gguf pacakge correctly + ver = importlib.metadata.version('gguf') + transformers.utils.import_utils._is_gguf_available = True # pylint: disable=protected-access + transformers.utils.import_utils._gguf_version = ver # pylint: disable=protected-access + diffusers.utils.import_utils._is_gguf_available = True # pylint: disable=protected-access + diffusers.utils.import_utils._gguf_version = ver # pylint: disable=protected-access + shared.log.debug(f'Load GGUF: version={ver}') + return gguf + + +def load_gguf_state_dict(path: str, compute_dtype: torch.dtype) -> dict: + gguf = install_gguf() + from .gguf_utils import TORCH_COMPATIBLE_QTYPES + from .gguf_tensor import GGMLTensor sd: dict[str, GGMLTensor] = {} stats = {} reader = gguf.GGUFReader(path) @@ -19,3 +43,14 @@ def load_gguf_state_dict(path: str, compute_dtype: torch.dtype) -> dict[str, GGM stats[tensor.tensor_type.name] = 0 stats[tensor.tensor_type.name] += 1 return sd, stats + + +def load_gguf(path, cls, compute_dtype: torch.dtype): + _gguf = install_gguf() + module = cls.from_single_file( + path, + quantization_config = diffusers.GGUFQuantizationConfig(compute_dtype=compute_dtype), + torch_dtype=compute_dtype, + ) + module.gguf = 'gguf' + return module diff --git a/modules/ggml/gguf_tensor.py b/modules/ggml/gguf_tensor.py index 4bc9117cb..8b2f608ac 100644 --- a/modules/ggml/gguf_tensor.py +++ b/modules/ggml/gguf_tensor.py @@ -131,7 +131,6 @@ def get_dequantized_tensor(self): if self._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES: return self.quantized_data.to(self.compute_dtype) elif self._ggml_quantization_type in DEQUANTIZE_FUNCTIONS: - # TODO(ryand): Look into how the dtype param is intended to be used. 
return dequantize( data=self.quantized_data, qtype=self._ggml_quantization_type, oshape=self.tensor_shape, dtype=None ).to(self.compute_dtype) diff --git a/modules/hashes.py b/modules/hashes.py index cf83794b0..a003f4840 100644 --- a/modules/hashes.py +++ b/modules/hashes.py @@ -9,6 +9,13 @@ cache_data = None progress_ok = True + +def init_cache(): + global cache_data # pylint: disable=global-statement + if cache_data is None: + cache_data = {} if not os.path.isfile(cache_filename) else shared.readfile(cache_filename, lock=True) + + def dump_cache(): shared.writefile(cache_data, cache_filename) diff --git a/modules/hidiffusion/hidiffusion.py b/modules/hidiffusion/hidiffusion.py index 7874f03af..d6f68eb15 100644 --- a/modules/hidiffusion/hidiffusion.py +++ b/modules/hidiffusion/hidiffusion.py @@ -234,7 +234,7 @@ def window_reverse(windows, window_size, H, W, shift_size): norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp if self._chunk_size is not None: - ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) # pylint: disable=undefined-variable # TODO hidiffusion undefined + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) # pylint: disable=undefined-variable else: ff_output = self.ff(norm_hidden_states) if self.use_ada_layer_norm_zero: @@ -308,7 +308,7 @@ def forward( self.T1 = int(self.max_timestep * self.T1_ratio) output_states = () - _scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 # TODO hidiffusion unused + _scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 blocks = list(zip(self.resnets, self.attentions)) @@ -407,7 +407,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: - def fix_scale(first, second): # TODO hidiffusion breaks hidden_scale.shape on 3rd generate with sdxl + def fix_scale(first, second): if (first.shape[-1] != second.shape[-1] or first.shape[-2] != second.shape[-2]): rescale = min(second.shape[-2] / first.shape[-2], second.shape[-1] / first.shape[-1]) # log.debug(f"HiDiffusion rescale: {hidden_states.shape} => {res_hidden_states_tuple[0].shape} scale={rescale}") diff --git a/modules/images.py b/modules/images.py index 910349bef..2cfbe941d 100644 --- a/modules/images.py +++ b/modules/images.py @@ -267,6 +267,8 @@ def safe_decode_string(s: bytes): def read_info_from_image(image: Image, watermark: bool = False): + if image is None: + return '', {} items = image.info or {} geninfo = items.pop('parameters', None) or items.pop('UserComment', None) if geninfo is not None and len(geninfo) > 0: diff --git a/modules/images_namegen.py b/modules/images_namegen.py index bc58f728a..7cad39b92 100644 --- a/modules/images_namegen.py +++ b/modules/images_namegen.py @@ -45,12 +45,12 @@ class FilenameGenerator: 'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8], 'sampler': lambda self: self.p and self.p.sampler_name, - 'seed': lambda self: self.seed and str(self.seed) or '', + 'seed': lambda self: (self.seed and str(self.seed)) or '', 'steps': lambda self: self.p and getattr(self.p, 'steps', 0), 'cfg': lambda self: self.p and getattr(self.p, 'cfg_scale', 0), 'clip_skip': lambda self: self.p and getattr(self.p, 'clip_skip', 0), 'denoising': lambda self: self.p and getattr(self.p, 'denoising_strength', 0), - 'styles': lambda self: self.p and ", ".join([style for 
style in self.p.styles if not style == "None"]) or "None", + 'styles': lambda self: (self.p and ", ".join([style for style in self.p.styles if not style == "None"])) or "None", 'uuid': lambda self: str(uuid.uuid4()), } default_time_format = '%Y%m%d%H%M%S' diff --git a/modules/images_resize.py b/modules/images_resize.py index d86ff6f22..362be79ee 100644 --- a/modules/images_resize.py +++ b/modules/images_resize.py @@ -5,13 +5,13 @@ from modules import shared -def resize_image(resize_mode, im, width, height, upscaler_name=None, output_type='image', context=None): +def resize_image(resize_mode: int, im: Image.Image, width: int, height: int, upscaler_name: str=None, output_type: str='image', context: str=None): upscaler_name = upscaler_name or shared.opts.upscaler_for_img2img def latent(im, w, h, upscaler): from modules.processing_vae import vae_encode, vae_decode import torch - latents = vae_encode(im, shared.sd_model, full_quality=False) # TODO enable full VAE mode for resize-latent + latents = vae_encode(im, shared.sd_model, full_quality=False) # TODO resize image: enable full VAE mode for resize-latent latents = torch.nn.functional.interpolate(latents, size=(int(h // 8), int(w // 8)), mode=upscaler["mode"], antialias=upscaler["antialias"]) im = vae_decode(latents, shared.sd_model, output_type='pil', full_quality=False)[0] return im @@ -79,18 +79,18 @@ def fill(im, color=None): def context_aware(im, width, height, context): import seam_carving # https://github.com/li-plus/seam-carving - if 'forward' in context: + if 'forward' in context.lower(): energy_mode = "forward" - elif 'backward' in context: + elif 'backward' in context.lower(): energy_mode = "backward" else: return im - if 'Add' in context: + if 'add' in context.lower(): src_ratio = min(width / im.width, height / im.height) src_w = int(im.width * src_ratio) src_h = int(im.height * src_ratio) src_image = resize(im, src_w, src_h) - elif 'Remove' in context: + elif 'remove' in context.lower(): ratio = width / height src_ratio = im.width / im.height src_w = width if ratio > src_ratio else im.width * height // im.height @@ -122,7 +122,7 @@ def context_aware(im, width, height, context): from modules import masking res = fill(im, color=0) res, _mask = masking.outpaint(res) - elif resize_mode == 5: # context-aware + elif resize_mode == 5: # context-aware res = context_aware(im, width, height, context) else: res = im.copy() diff --git a/modules/img2img.py b/modules/img2img.py index 8274386cc..2e3eca54d 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -1,6 +1,7 @@ import os import itertools # SBM Batch frames import numpy as np +import filetype from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError import modules.scripts from modules import shared, processing, images @@ -8,7 +9,6 @@ from modules.ui import plaintext_to_html from modules.memstats import memory_stats - debug = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None debug('Trace: PROCESS') @@ -16,24 +16,25 @@ def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args): shared.log.debug(f'batch: {input_files}|{input_dir}|{output_dir}|{inpaint_mask_dir}') processing.fix_seed(p) + image_files = [] if input_files is not None and len(input_files) > 0: image_files = [f.name for f in input_files] - else: - if not os.path.isdir(input_dir): - shared.log.error(f"Process batch: directory not found: {input_dir}") - return - image_files = os.listdir(input_dir) - 
image_files = [os.path.join(input_dir, f) for f in image_files] + image_files = [f for f in image_files if filetype.is_image(f)] + shared.log.info(f'Process batch: input images={len(image_files)}') + elif os.path.isdir(input_dir): + image_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir)] + image_files = [f for f in image_files if filetype.is_image(f)] + shared.log.info(f'Process batch: input folder="{input_dir}" images={len(image_files)}') is_inpaint_batch = False - if inpaint_mask_dir: - inpaint_masks = os.listdir(inpaint_mask_dir) - inpaint_masks = [os.path.join(inpaint_mask_dir, f) for f in inpaint_masks] + if inpaint_mask_dir and os.path.isdir(inpaint_mask_dir): + inpaint_masks = [os.path.join(inpaint_mask_dir, f) for f in os.listdir(inpaint_mask_dir)] + inpaint_masks = [f for f in inpaint_masks if filetype.is_image(f)] is_inpaint_batch = len(inpaint_masks) > 0 - if is_inpaint_batch: - shared.log.info(f"Process batch: inpaint batch masks={len(inpaint_masks)}") + shared.log.info(f'Process batch: mask folder="{input_dir}" images={len(inpaint_masks)}') save_normally = output_dir == '' p.do_not_save_grid = True p.do_not_save_samples = not save_normally + p.default_prompt = p.prompt shared.state.job_count = len(image_files) * p.n_iter if shared.opts.batch_frame_mode: # SBM Frame mode is on, process each image in batch with same seed window_size = p.batch_size @@ -55,14 +56,29 @@ def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args) for image_file in batch_image_files: try: img = Image.open(image_file) - if p.scale_by != 1: - p.width = int(img.width * p.scale_by) - p.height = int(img.height * p.scale_by) + img = ImageOps.exif_transpose(img) + batch_images.append(img) + # p.init() + p.width = int(img.width * p.scale_by) + p.height = int(img.height * p.scale_by) + caption_file = os.path.splitext(image_file)[0] + '.txt' + prompt_type='default' + if os.path.exists(caption_file): + with open(caption_file, 'r', encoding='utf8') as f: + p.prompt = f.read() + prompt_type='file' + else: + p.prompt = p.default_prompt + p.all_prompts = None + p.all_negative_prompts = None + p.all_seeds = None + p.all_subseeds = None + shared.log.debug(f'Process batch: image="{image_file}" prompt={prompt_type} i={i+1}/{len(image_files)}') except UnidentifiedImageError as e: - shared.log.error(f"Image error: {e}") - continue - img = ImageOps.exif_transpose(img) - batch_images.append(img) + shared.log.error(f'Process batch: image="{image_file}" {e}') + if len(batch_images) == 0: + shared.log.warning("Process batch: no images found in batch") + continue batch_images = batch_images * btcrept # Standard mode sends the same image per batchsize. p.init_images = batch_images @@ -81,17 +97,20 @@ def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args) batch_image_files = batch_image_files * btcrept # List used for naming later. 
- proc = modules.scripts.scripts_img2img.run(p, *args) - if proc is None: - proc = processing.process_images(p) - for n, (image, image_file) in enumerate(itertools.zip_longest(proc.images,batch_image_files)): + processed = modules.scripts.scripts_img2img.run(p, *args) + if processed is None: + processed = processing.process_images(p) + + for n, (image, image_file) in enumerate(itertools.zip_longest(processed.images, batch_image_files)): + if image is None: + continue basename = '' if shared.opts.use_original_name_batch: forced_filename, ext = os.path.splitext(os.path.basename(image_file)) else: forced_filename = None ext = shared.opts.samples_format - if len(proc.images) > 1: + if len(processed.images) > 1: basename = f'{n + i}' if shared.opts.batch_frame_mode else f'{n}' else: basename = '' @@ -103,7 +122,7 @@ def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args) for k, v in items.items(): image.info[k] = v images.save_image(image, path=output_dir, basename=basename, seed=None, prompt=None, extension=ext, info=geninfo, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=image.info, forced_filename=forced_filename) - proc = modules.scripts.scripts_img2img.after(p, proc, *args) + processed = modules.scripts.scripts_img2img.after(p, processed, *args) shared.log.debug(f'Processed: images={len(batch_image_files)} memory={memory_stats()} batch') @@ -147,29 +166,20 @@ def img2img(id_task: str, state: str, mode: int, debug(f'img2img: id_task={id_task}|mode={mode}|prompt={prompt}|negative_prompt={negative_prompt}|prompt_styles={prompt_styles}|init_img={init_img}|sketch={sketch}|init_img_with_mask={init_img_with_mask}|inpaint_color_sketch={inpaint_color_sketch}|inpaint_color_sketch_orig={inpaint_color_sketch_orig}|init_img_inpaint={init_img_inpaint}|init_mask_inpaint={init_mask_inpaint}|steps={steps}|sampler_index={sampler_index}||mask_blur={mask_blur}|mask_alpha={mask_alpha}|inpainting_fill={inpainting_fill}|full_quality={full_quality}|detailer={detailer}|tiling={tiling}|hidiffusion={hidiffusion}|n_iter={n_iter}|batch_size={batch_size}|cfg_scale={cfg_scale}|image_cfg_scale={image_cfg_scale}|clip_skip={clip_skip}|denoising_strength={denoising_strength}|seed={seed}|subseed{subseed}|subseed_strength={subseed_strength}|seed_resize_from_h={seed_resize_from_h}|seed_resize_from_w={seed_resize_from_w}|selected_scale_tab={selected_scale_tab}|height={height}|width={width}|scale_by={scale_by}|resize_mode={resize_mode}|resize_name={resize_name}|resize_context={resize_context}|inpaint_full_res={inpaint_full_res}|inpaint_full_res_padding={inpaint_full_res_padding}|inpainting_mask_invert={inpainting_mask_invert}|img2img_batch_files={img2img_batch_files}|img2img_batch_input_dir={img2img_batch_input_dir}|img2img_batch_output_dir={img2img_batch_output_dir}|img2img_batch_inpaint_mask_dir={img2img_batch_inpaint_mask_dir}|override_settings_texts={override_settings_texts}') - if mode == 5: - if img2img_batch_files is None or len(img2img_batch_files) == 0: - shared.log.debug('Init bactch images not set') - elif init_img: - shared.log.debug('Init image not set') - if sampler_index is None: shared.log.warning('Sampler: invalid') sampler_index = 0 + mode = int(mode) + image = None + mask = None override_settings = create_override_settings_dict(override_settings_texts) - if mode == 0: # img2img + if mode == 0: # img2img if init_img is None: return [], '', '', 'Error: init image not provided' image = init_img.convert("RGB") - mask = None - elif mode == 1: # 
img2img sketch - if sketch is None: - return [], '', '', 'Error: sketch image not provided' - image = sketch.convert("RGB") - mask = None - elif mode == 2: # inpaint + elif mode == 1: # inpaint if init_img_with_mask is None: return [], '', '', 'Error: init image with mask not provided' image = init_img_with_mask["image"] @@ -177,7 +187,11 @@ def img2img(id_task: str, state: str, mode: int, alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') image = image.convert("RGB") - elif mode == 3: # inpaint sketch + elif mode == 2: # sketch + if sketch is None: + return [], '', '', 'Error: sketch image not provided' + image = sketch.convert("RGB") + elif mode == 3: # composite if inpaint_color_sketch is None: return [], '', '', 'Error: color sketch image not provided' image = inpaint_color_sketch @@ -188,15 +202,16 @@ def img2img(id_task: str, state: str, mode: int, blur = ImageFilter.GaussianBlur(mask_blur) image = Image.composite(image.filter(blur), orig, mask.filter(blur)) image = image.convert("RGB") - elif mode == 4: # inpaint upload mask + elif mode == 4: # inpaint upload mask if init_img_inpaint is None: return [], '', '', 'Error: inpaint image not provided' image = init_img_inpaint mask = init_mask_inpaint + elif mode == 5: # process batch + pass # handled later else: shared.log.error(f'Image processing unknown mode: {mode}') - image = None - mask = None + if image is not None: image = ImageOps.exif_transpose(image) if selected_scale_tab == 1 and resize_mode != 0: diff --git a/modules/infotext.py b/modules/infotext.py index 05e06e600..baa995c88 100644 --- a/modules/infotext.py +++ b/modules/infotext.py @@ -10,6 +10,7 @@ debug = lambda *args, **kwargs: None # pylint: disable=unnecessary-lambda-assignment re_size = re.compile(r"^(\d+)x(\d+)$") # int x int re_param = re.compile(r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)') # multi-word: value +re_lora = re.compile("= 2.3: torch.cuda._initialization_lock = torch.xpu._initialization_lock torch.cuda._initialized = torch.xpu._initialized torch.cuda._is_in_bad_fork = torch.xpu._is_in_bad_fork @@ -122,7 +122,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.traceback = torch.xpu.traceback # Memory: - if legacy and 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): + if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): torch.xpu.empty_cache = lambda: None torch.cuda.empty_cache = torch.xpu.empty_cache @@ -159,7 +159,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd torch.cuda.amp = torch.xpu.amp - if not ipex.__version__.startswith("2.3"): + if float(ipex.__version__[:3]) < 2.3: torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype @@ -178,7 +178,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler # C - if legacy and not ipex.__version__.startswith("2.3"): + if legacy and float(ipex.__version__[:3]) < 2.3: torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count ipex._C._DeviceProperties.major = 12 diff --git a/modules/intel/ipex/diffusers.py b/modules/intel/ipex/diffusers.py index 
f742fe5c0..5bf5bbe39 100644 --- a/modules/intel/ipex/diffusers.py +++ b/modules/intel/ipex/diffusers.py @@ -1,7 +1,7 @@ import os from functools import wraps, cache import torch -import diffusers #0.29.1 # pylint: disable=import-error +import diffusers # pylint: disable=import-error from diffusers.models.attention_processor import Attention # pylint: disable=protected-access, missing-function-docstring, line-too-long @@ -20,20 +20,31 @@ def fourier_filter(x_in, threshold, scale): # fp64 error -def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: - assert dim % 2 == 0, "The dimension must be even." - - scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim # force fp32 instead of fp64 - omega = 1.0 / (theta**scale) - - batch_size, seq_length = pos.shape - out = torch.einsum("...n,d->...nd", pos, omega) - cos_out = torch.cos(out) - sin_out = torch.sin(out) - - stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) - out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) - return out.float() +class FluxPosEmbed(torch.nn.Module): + def __init__(self, theta: int, axes_dim): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + for i in range(n_axes): + cos, sin = diffusers.models.embeddings.get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=torch.float32, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin @cache @@ -337,4 +348,5 @@ def ipex_diffusers(): if not device_supports_fp64 or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None: diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor diffusers.models.attention_processor.AttnProcessor = AttnProcessor - diffusers.models.transformers.transformer_flux.rope = rope + if not device_supports_fp64: + diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed diff --git a/modules/intel/ipex/hijacks.py b/modules/intel/ipex/hijacks.py index 7ec94138d..b1c9a1182 100644 --- a/modules/intel/ipex/hijacks.py +++ b/modules/intel/ipex/hijacks.py @@ -149,6 +149,15 @@ def functional_linear(input, weight, bias=None): bias.data = bias.data.to(dtype=weight.data.dtype) return original_functional_linear(input, weight, bias=bias) +original_functional_conv1d = torch.nn.functional.conv1d +@wraps(torch.nn.functional.conv1d) +def functional_conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_conv1d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + original_functional_conv2d = torch.nn.functional.conv2d @wraps(torch.nn.functional.conv2d) def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): @@ -158,6 +167,16 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, bias.data = bias.data.to(dtype=weight.data.dtype) return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) +# LTX Video 
+original_functional_conv3d = torch.nn.functional.conv3d +@wraps(torch.nn.functional.conv3d) +def functional_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_conv3d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + # SwinIR BF16: original_functional_pad = torch.nn.functional.pad @wraps(torch.nn.functional.pad) @@ -294,7 +313,7 @@ def torch_load(f, map_location=None, *args, **kwargs): # Hijack Functions: def ipex_hijacks(legacy=True): - if legacy: + if legacy and float(torch.__version__[:3]) < 2.5: torch.nn.functional.interpolate = interpolate torch.tensor = torch_tensor torch.Tensor.to = Tensor_to @@ -320,7 +339,9 @@ def ipex_hijacks(legacy=True): torch.nn.functional.group_norm = functional_group_norm torch.nn.functional.layer_norm = functional_layer_norm torch.nn.functional.linear = functional_linear + torch.nn.functional.conv1d = functional_conv1d torch.nn.functional.conv2d = functional_conv2d + torch.nn.functional.conv3d = functional_conv3d torch.nn.functional.pad = functional_pad torch.bmm = torch_bmm diff --git a/modules/ipadapter.py b/modules/ipadapter.py index d5bfbec8c..aa010c33d 100644 --- a/modules/ipadapter.py +++ b/modules/ipadapter.py @@ -9,11 +9,16 @@ import time import json from PIL import Image -from modules import processing, shared, devices, sd_models +import diffusers +import transformers +from modules import processing, shared, devices, sd_models, errors -clip_repo = "h94/IP-Adapter" clip_loaded = None +adapters_loaded = [] +CLIP_ID = "h94/IP-Adapter" +OPEN_ID = "openai/clip-vit-large-patch14" +SIGLIP_ID = 'google/siglip-so400m-patch14-384' ADAPTERS_NONE = { 'None': { 'name': 'none', 'repo': 'none', 'subfolder': 'none' }, } @@ -36,11 +41,13 @@ 'Ostris Composition ViT-H SDXL': { 'name': 'ip_plus_composition_sdxl.safetensors', 'repo': 'ostris/ip-composition-adapter', 'subfolder': '' }, } ADAPTERS_SD3 = { - 'InstantX Large': { 'name': 'ip-adapter.bin', 'repo': 'InstantX/SD3.5-Large-IP-Adapter' }, + 'None': { 'name': 'none', 'repo': 'none', 'subfolder': 'none' }, + 'InstantX Large': { 'name': 'none', 'repo': 'InstantX/SD3.5-Large-IP-Adapter', 'subfolder': 'none', 'revision': 'refs/pr/10' }, } ADAPTERS_F1 = { - 'XLabs AI v1': { 'name': 'ip_adapter.safetensors', 'repo': 'XLabs-AI/flux-ip-adapter' }, - 'XLabs AI v2': { 'name': 'ip_adapter.safetensors', 'repo': 'XLabs-AI/flux-ip-adapter-v2' }, + 'None': { 'name': 'none', 'repo': 'none', 'subfolder': 'none' }, + 'XLabs AI v1': { 'name': 'ip_adapter.safetensors', 'repo': 'XLabs-AI/flux-ip-adapter', 'subfolder': 'none' }, + 'XLabs AI v2': { 'name': 'ip_adapter.safetensors', 'repo': 'XLabs-AI/flux-ip-adapter-v2', 'subfolder': 'none' }, } ADAPTERS = { **ADAPTERS_SD15, **ADAPTERS_SDXL, **ADAPTERS_SD3, **ADAPTERS_F1 } ADAPTERS_ALL = { **ADAPTERS_SD15, **ADAPTERS_SDXL, **ADAPTERS_SD3, **ADAPTERS_F1 } @@ -125,14 +132,27 @@ def crop_images(images, crops): shared.log.error(f'IP adapter: failed to crop image: source={len(images[i])} faces={len(cropped)}') except Exception as e: shared.log.error(f'IP adapter: failed to crop image: {e}') + if shared.sd_model_type == 'sd3' and len(images) == 1: + return images[0] return images -def unapply(pipe): # pylint: disable=arguments-differ +def unapply(pipe, unload: bool = False): # pylint: 
disable=arguments-differ + if len(adapters_loaded) == 0: + return try: if hasattr(pipe, 'set_ip_adapter_scale'): pipe.set_ip_adapter_scale(0) - if hasattr(pipe, 'unet') and hasattr(pipe.unet, 'config') and pipe.unet.config.encoder_hid_dim_type == 'ip_image_proj': + if unload: + shared.log.debug('IP adapter unload') + pipe.unload_ip_adapter() + if hasattr(pipe, 'unet'): + module = pipe.unet + elif hasattr(pipe, 'transformer'): + module = pipe.transformer + else: + module = None + if module is not None and hasattr(module, 'config') and module.config.encoder_hid_dim_type == 'ip_image_proj': pipe.unet.encoder_hid_proj = None pipe.config.encoder_hid_dim_type = None pipe.unet.set_default_attn_processor() @@ -140,27 +160,78 @@ def unapply(pipe): # pylint: disable=arguments-differ pass -def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapter_scales=[1.0], adapter_crops=[False], adapter_starts=[0.0], adapter_ends=[1.0], adapter_images=[]): +def load_image_encoder(pipe: diffusers.DiffusionPipeline, adapter_names: list[str]): global clip_loaded # pylint: disable=global-statement - # overrides - if hasattr(p, 'ip_adapter_names'): - if isinstance(p.ip_adapter_names, str): - p.ip_adapter_names = [p.ip_adapter_names] - adapters = [ADAPTERS_ALL.get(adapter_name, None) for adapter_name in p.ip_adapter_names if adapter_name is not None and adapter_name.lower() != 'none'] - adapter_names = p.ip_adapter_names - else: - if isinstance(adapter_names, str): - adapter_names = [adapter_names] - adapters = [ADAPTERS.get(adapter, None) for adapter in adapter_names] - adapters = [adapter for adapter in adapters if adapter is not None and adapter['name'].lower() != 'none'] - if len(adapters) == 0: - unapply(pipe) - if hasattr(p, 'ip_adapter_images'): - del p.ip_adapter_images - return False - if shared.sd_model_type not in ['sd', 'sdxl', 'sd3', 'f1']: - shared.log.error(f'IP adapter: model={shared.sd_model_type} class={pipe.__class__.__name__} not supported') - return False + for adapter_name in adapter_names: + # which clip to use + clip_repo = CLIP_ID + if 'ViT' not in adapter_name: # defaults per model + clip_subfolder = 'models/image_encoder' if shared.sd_model_type == 'sd' else 'sdxl_models/image_encoder' + if 'ViT-H' in adapter_name: + clip_subfolder = 'models/image_encoder' # this is vit-h + elif 'ViT-G' in adapter_name: + clip_subfolder = 'sdxl_models/image_encoder' # this is vit-g + else: + if shared.sd_model_type == 'sd': + clip_subfolder = 'models/image_encoder' + elif shared.sd_model_type == 'sdxl': + clip_subfolder = 'sdxl_models/image_encoder' + elif shared.sd_model_type == 'sd3': + clip_repo = SIGLIP_ID + clip_subfolder = None + elif shared.sd_model_type == 'f1': + clip_repo = OPEN_ID + clip_subfolder = None + else: + shared.log.error(f'IP adapter: unknown model type: {adapter_name}') + return False + + # load image encoder used by ip adapter + if pipe.image_encoder is None or clip_loaded != f'{clip_repo}/{clip_subfolder}': + try: + if shared.sd_model_type == 'sd3': + image_encoder = transformers.SiglipVisionModel.from_pretrained(clip_repo, torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir) + else: + if clip_subfolder is None: + image_encoder = transformers.CLIPVisionModelWithProjection.from_pretrained(clip_repo, torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir, use_safetensors=True) + shared.log.debug(f'IP adapter load: encoder="{clip_repo}" cls={pipe.image_encoder.__class__.__name__}') + else: + image_encoder = 
transformers.CLIPVisionModelWithProjection.from_pretrained(clip_repo, subfolder=clip_subfolder, torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir, use_safetensors=True) + shared.log.debug(f'IP adapter load: encoder="{clip_repo}/{clip_subfolder}" cls={pipe.image_encoder.__class__.__name__}') + if hasattr(pipe, 'register_modules'): + pipe.register_modules(image_encoder=image_encoder) + else: + pipe.image_encoder = image_encoder + clip_loaded = f'{clip_repo}/{clip_subfolder}' + except Exception as e: + shared.log.error(f'IP adapter load: encoder="{clip_repo}/{clip_subfolder}" {e}') + errors.display(e, 'IP adapter: type=encoder') + return False + sd_models.move_model(pipe.image_encoder, devices.device) + return True + + +def load_feature_extractor(pipe): + # load feature extractor used by ip adapter + if pipe.feature_extractor is None: + try: + if shared.sd_model_type == 'sd3': + feature_extractor = transformers.SiglipImageProcessor.from_pretrained(SIGLIP_ID, torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir) + else: + feature_extractor = transformers.CLIPImageProcessor() + if hasattr(pipe, 'register_modules'): + pipe.register_modules(feature_extractor=feature_extractor) + else: + pipe.feature_extractor = feature_extractor + shared.log.debug(f'IP adapter load: extractor={pipe.feature_extractor.__class__.__name__}') + except Exception as e: + shared.log.error(f'IP adapter load: extractor {e}') + errors.display(e, 'IP adapter: type=extractor') + return False + return True + + +def parse_params(p: processing.StableDiffusionProcessing, adapters: list, adapter_scales: list[float], adapter_crops: list[bool], adapter_starts: list[float], adapter_ends: list[float], adapter_images: list): if hasattr(p, 'ip_adapter_scales'): adapter_scales = p.ip_adapter_scales if hasattr(p, 'ip_adapter_crops'): @@ -201,6 +272,33 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt p.ip_adapter_starts = adapter_starts.copy() adapter_ends = get_scales(adapter_ends, adapter_images) p.ip_adapter_ends = adapter_ends.copy() + return adapter_images, adapter_masks, adapter_scales, adapter_crops, adapter_starts, adapter_ends + + +def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapter_scales=[1.0], adapter_crops=[False], adapter_starts=[0.0], adapter_ends=[1.0], adapter_images=[]): + global adapters_loaded # pylint: disable=global-statement + # overrides + if hasattr(p, 'ip_adapter_names'): + if isinstance(p.ip_adapter_names, str): + p.ip_adapter_names = [p.ip_adapter_names] + adapters = [ADAPTERS_ALL.get(adapter_name, None) for adapter_name in p.ip_adapter_names if adapter_name is not None and adapter_name.lower() != 'none'] + adapter_names = p.ip_adapter_names + else: + if isinstance(adapter_names, str): + adapter_names = [adapter_names] + adapters = [ADAPTERS.get(adapter_name, None) for adapter_name in adapter_names if adapter_name.lower() != 'none'] + + if len(adapters) == 0: + unapply(pipe, getattr(p, 'ip_adapter_unload', False)) + if hasattr(p, 'ip_adapter_images'): + del p.ip_adapter_images + return False + if shared.sd_model_type not in ['sd', 'sdxl', 'sd3', 'f1']: + shared.log.error(f'IP adapter: model={shared.sd_model_type} class={pipe.__class__.__name__} not supported') + return False + + adapter_images, adapter_masks, adapter_scales, adapter_crops, adapter_starts, adapter_ends = parse_params(p, adapters, adapter_scales, adapter_crops, adapter_starts, adapter_ends, adapter_images) + # init code if pipe is None: return False @@ -211,7 
+309,7 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt shared.log.error('IP adapter: no image provided') adapters = [] # unload adapter if previously loaded as it will cause runtime errors if len(adapters) == 0: - unapply(pipe) + unapply(pipe, getattr(p, 'ip_adapter_unload', False)) if hasattr(p, 'ip_adapter_images'): del p.ip_adapter_images return False @@ -219,61 +317,30 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt shared.log.error(f'IP adapter: pipeline not supported: {pipe.__class__.__name__}') return False - for adapter_name in adapter_names: - # which clip to use - if 'ViT' not in adapter_name: # defaults per model - if shared.sd_model_type == 'sd': - clip_subfolder = 'models/image_encoder' - else: - clip_subfolder = 'sdxl_models/image_encoder' - if 'ViT-H' in adapter_name: - clip_subfolder = 'models/image_encoder' # this is vit-h - elif 'ViT-G' in adapter_name: - clip_subfolder = 'sdxl_models/image_encoder' # this is vit-g - else: - if shared.sd_model_type == 'sd': - clip_subfolder = 'models/image_encoder' - elif shared.sd_model_type == 'sdxl': - clip_subfolder = 'sdxl_models/image_encoder' - elif shared.sd_model_type == 'sd3': - shared.log.error(f'IP adapter: adapter={adapter_name} type={shared.sd_model_type} cls={shared.sd_model.__class__.__name__}: unsupported base model') - return False - elif shared.sd_model_type == 'f1': - shared.log.error(f'IP adapter: adapter={adapter_name} type={shared.sd_model_type} cls={shared.sd_model.__class__.__name__}: unsupported base model') - return False - else: - shared.log.error(f'IP adapter: unknown model type: {adapter_name}') - return False - - # load feature extractor used by ip adapter - if pipe.feature_extractor is None: - try: - from transformers import CLIPImageProcessor - shared.log.debug('IP adapter load: feature extractor') - pipe.feature_extractor = CLIPImageProcessor() - except Exception as e: - shared.log.error(f'IP adapter load: feature extractor {e}') - return False + if not load_image_encoder(pipe, adapter_names): + return False - # load image encoder used by ip adapter - if pipe.image_encoder is None or clip_loaded != f'{clip_repo}/{clip_subfolder}': - try: - from transformers import CLIPVisionModelWithProjection - shared.log.debug(f'IP adapter load: image encoder="{clip_repo}/{clip_subfolder}"') - pipe.image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_repo, subfolder=clip_subfolder, torch_dtype=devices.dtype, cache_dir=shared.opts.diffusers_dir, use_safetensors=True) - clip_loaded = f'{clip_repo}/{clip_subfolder}' - except Exception as e: - shared.log.error(f'IP adapter load: image encoder="{clip_repo}/{clip_subfolder}" {e}') - return False - sd_models.move_model(pipe.image_encoder, devices.device) + if not load_feature_extractor(pipe): + return False # main code try: t0 = time.time() - repos = [adapter['repo'] for adapter in adapters] - subfolders = [adapter['subfolder'] for adapter in adapters] - names = [adapter['name'] for adapter in adapters] - pipe.load_ip_adapter(repos, subfolder=subfolders, weight_name=names) + repos = [adapter.get('repo', None) for adapter in adapters if adapter.get('repo', 'none') != 'none'] + subfolders = [adapter.get('subfolder', None) for adapter in adapters if adapter.get('subfolder', 'none') != 'none'] + names = [adapter.get('name', None) for adapter in adapters if adapter.get('name', 'none') != 'none'] + revisions = [adapter.get('revision', None) for adapter in adapters if adapter.get('revision', 'none') 
!= 'none'] + kwargs = {} + if len(repos) == 1: + repos = repos[0] + if len(subfolders) > 0: + kwargs['subfolder'] = subfolders if len(subfolders) > 1 else subfolders[0] + if len(names) > 0: + kwargs['weight_name'] = names if len(names) > 1 else names[0] + if len(revisions) > 0: + kwargs['revision'] = revisions[0] + pipe.load_ip_adapter(repos, **kwargs) + adapters_loaded = names if hasattr(p, 'ip_adapter_layers'): pipe.set_ip_adapter_scale(p.ip_adapter_layers) ip_str = ';'.join(adapter_names) + ':' + json.dumps(p.ip_adapter_layers) @@ -281,8 +348,8 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt for i in range(len(adapter_scales)): if adapter_starts[i] > 0: adapter_scales[i] = 0.00 - pipe.set_ip_adapter_scale(adapter_scales) - ip_str = [f'{os.path.splitext(adapter)[0]}:{scale}:{start}:{end}' for adapter, scale, start, end in zip(adapter_names, adapter_scales, adapter_starts, adapter_ends)] + pipe.set_ip_adapter_scale(adapter_scales if len(adapter_scales) > 1 else adapter_scales[0]) + ip_str = [f'{os.path.splitext(adapter)[0]}:{scale}:{start}:{end}:{crop}' for adapter, scale, start, end, crop in zip(adapter_names, adapter_scales, adapter_starts, adapter_ends, adapter_crops)] p.task_args['ip_adapter_image'] = crop_images(adapter_images, adapter_crops) if len(adapter_masks) > 0: p.cross_attention_kwargs = { 'ip_adapter_masks': adapter_masks } @@ -291,4 +358,5 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt shared.log.info(f'IP adapter: {ip_str} image={adapter_images} mask={adapter_masks is not None} time={t1-t0:.2f}') except Exception as e: shared.log.error(f'IP adapter load: adapters={adapter_names} repo={repos} folders={subfolders} names={names} {e}') + errors.display(e, 'IP adapter: type=adapter') return True diff --git a/modules/loader.py b/modules/loader.py index cd51cc8eb..38d942fd4 100644 --- a/modules/loader.py +++ b/modules/loader.py @@ -14,7 +14,7 @@ logging.getLogger("DeepSpeed").disabled = True -os.environ.setdefault('TORCH_LOGS', '-all') +# os.environ.setdefault('TORCH_LOGS', '-all') import torch # pylint: disable=C0411 if torch.__version__.startswith('2.5.0'): errors.log.warning(f'Disabling cuDNN for SDP on torch={torch.__version__}') diff --git a/modules/lora/extra_networks_lora.py b/modules/lora/extra_networks_lora.py new file mode 100644 index 000000000..42c4a92f6 --- /dev/null +++ b/modules/lora/extra_networks_lora.py @@ -0,0 +1,152 @@ +import re +import numpy as np +import modules.lora.networks as networks +from modules import extra_networks, shared + + +# from https://github.com/cheald/sd-webui-loractl/blob/master/loractl/lib/utils.py +def get_stepwise(param, step, steps): + def sorted_positions(raw_steps): + steps = [[float(s.strip()) for s in re.split("[@~]", x)] + for x in re.split("[,;]", str(raw_steps))] + if len(steps[0]) == 1: # If we just got a single number, just return it + return steps[0][0] + steps = [[s[0], s[1] if len(s) == 2 else 1] for s in steps] # Add implicit 1s to any steps which don't have a weight + steps.sort(key=lambda k: k[1]) # Sort by index + steps = [list(v) for v in zip(*steps)] + return steps + + def calculate_weight(m, step, max_steps, step_offset=2): + if isinstance(m, list): + if m[1][-1] <= 1.0: + step = step / (max_steps - step_offset) if max_steps > 0 else 1.0 + v = np.interp(step, m[1], m[0]) + return v + else: + return m + + stepwise = calculate_weight(sorted_positions(param), step, steps) + return stepwise + + +def prompt(p): + if shared.opts.lora_apply_tags == 0: 
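+        # lora_apply_tags == 0 disables tag injection entirely; positive values later cap how many tags per network get appended to the prompt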
+ return + all_tags = [] + for loaded in networks.loaded_networks: + page = [en for en in shared.extra_networks if en.name == 'lora'][0] + item = page.create_item(loaded.name) + tags = (item or {}).get("tags", {}) + loaded.tags = list(tags) + if len(loaded.tags) == 0: + loaded.tags.append(loaded.name) + if shared.opts.lora_apply_tags > 0: + loaded.tags = loaded.tags[:shared.opts.lora_apply_tags] + all_tags.extend(loaded.tags) + if len(all_tags) > 0: + all_tags = list(set(all_tags)) + all_tags = [t for t in all_tags if t not in p.prompt] + shared.log.debug(f"Load network: type=LoRA tags={all_tags} max={shared.opts.lora_apply_tags} apply") + all_tags = ', '.join(all_tags) + p.extra_generation_params["LoRA tags"] = all_tags + if '_tags_' in p.prompt: + p.prompt = p.prompt.replace('_tags_', all_tags) + else: + p.prompt = f"{p.prompt}, {all_tags}" + if p.all_prompts is not None: + for i in range(len(p.all_prompts)): + if '_tags_' in p.all_prompts[i]: + p.all_prompts[i] = p.all_prompts[i].replace('_tags_', all_tags) + else: + p.all_prompts[i] = f"{p.all_prompts[i]}, {all_tags}" + + +def infotext(p): + names = [i.name for i in networks.loaded_networks] + if len(names) > 0: + p.extra_generation_params["LoRA networks"] = ", ".join(names) + if shared.opts.lora_add_hashes_to_infotext: + network_hashes = [] + for item in networks.loaded_networks: + if not item.network_on_disk.shorthash: + continue + network_hashes.append(item.network_on_disk.shorthash) + if len(network_hashes) > 0: + p.extra_generation_params["LoRA hashes"] = ", ".join(network_hashes) + + +def parse(p, params_list, step=0): + names = [] + te_multipliers = [] + unet_multipliers = [] + dyn_dims = [] + for params in params_list: + assert params.items + names.append(params.positional[0]) + te_multiplier = params.named.get("te", params.positional[1] if len(params.positional) > 1 else shared.opts.extra_networks_default_multiplier) + if isinstance(te_multiplier, str) and "@" in te_multiplier: + te_multiplier = get_stepwise(te_multiplier, step, p.steps) + else: + te_multiplier = float(te_multiplier) + unet_multiplier = [params.positional[2] if len(params.positional) > 2 else te_multiplier] * 3 + unet_multiplier = [params.named.get("unet", unet_multiplier[0])] * 3 + unet_multiplier[0] = params.named.get("in", unet_multiplier[0]) + unet_multiplier[1] = params.named.get("mid", unet_multiplier[1]) + unet_multiplier[2] = params.named.get("out", unet_multiplier[2]) + for i in range(len(unet_multiplier)): + if isinstance(unet_multiplier[i], str) and "@" in unet_multiplier[i]: + unet_multiplier[i] = get_stepwise(unet_multiplier[i], step, p.steps) + else: + unet_multiplier[i] = float(unet_multiplier[i]) + dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None + dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim + te_multipliers.append(te_multiplier) + unet_multipliers.append(unet_multiplier) + dyn_dims.append(dyn_dim) + return names, te_multipliers, unet_multipliers, dyn_dims + + +class ExtraNetworkLora(extra_networks.ExtraNetwork): + + def __init__(self): + super().__init__('lora') + self.active = False + self.model = None + self.errors = {} + + def activate(self, p, params_list, step=0, include=[], exclude=[]): + self.errors.clear() + if self.active: + if self.model != shared.opts.sd_model_checkpoint: # reset if model changed + self.active = False + if len(params_list) > 0 and not self.active: # activate patches once + # shared.log.debug(f'Activate network: type=LoRA 
model="{shared.opts.sd_model_checkpoint}"') + self.active = True + self.model = shared.opts.sd_model_checkpoint + if 'text_encoder' in include: + networks.timer.clear(complete=True) + names, te_multipliers, unet_multipliers, dyn_dims = parse(p, params_list, step) + networks.network_load(names, te_multipliers, unet_multipliers, dyn_dims) # load + networks.network_activate(include, exclude) + if len(networks.loaded_networks) > 0 and len(networks.applied_layers) > 0 and step == 0: + infotext(p) + prompt(p) + shared.log.info(f'Load network: type=LoRA apply={[n.name for n in networks.loaded_networks]} mode={"fuse" if shared.opts.lora_fuse_diffusers else "backup"} te={te_multipliers} unet={unet_multipliers} time={networks.timer.summary}') + + def deactivate(self, p): + if shared.native and len(networks.diffuser_loaded) > 0: + if hasattr(shared.sd_model, "unload_lora_weights") and hasattr(shared.sd_model, "text_encoder"): + if not (shared.compiled_model_state is not None and shared.compiled_model_state.is_compiled is True): + try: + if shared.opts.lora_fuse_diffusers: + shared.sd_model.unfuse_lora() + shared.sd_model.unload_lora_weights() # fails for non-CLIP models + except Exception: + pass + networks.network_deactivate() + if self.active and networks.debug: + shared.log.debug(f"Network end: type=LoRA time={networks.timer.summary}") + if self.errors: + for k, v in self.errors.items(): + shared.log.error(f'LoRA: name="{k}" errors={v}') + self.errors.clear() diff --git a/modules/lora/lora_convert.py b/modules/lora/lora_convert.py new file mode 100644 index 000000000..032ffa5a3 --- /dev/null +++ b/modules/lora/lora_convert.py @@ -0,0 +1,509 @@ +import os +import re +import bisect +from typing import Dict +import torch +from modules import shared + + +debug = os.environ.get('SD_LORA_DEBUG', None) is not None +suffix_conversion = { + "attentions": {}, + "resnets": { + "conv1": "in_layers_2", + "conv2": "out_layers_3", + "norm1": "in_layers_0", + "norm2": "out_layers_0", + "time_emb_proj": "emb_layers_1", + "conv_shortcut": "skip_connection", + } +} +re_digits = re.compile(r"\d+") +re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") +re_compiled = {} + + +def make_unet_conversion_map() -> Dict[str, str]: + unet_conversion_map_layer = [] + + for i in range(3): # num_blocks is 3 in sdxl + # loop over downblocks/upblocks + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3 * i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + # if i > 0: commentout for sdxl + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op." 
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{2}." # change for sdxl + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2 * j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0.", "norm1."), + ("in_layers.2.", "conv1."), + ("out_layers.0.", "norm2."), + ("out_layers.3.", "conv2."), + ("emb_layers.1.", "time_emb_proj."), + ("skip_connection.", "conv_shortcut."), + ] + + unet_conversion_map = [] + for sd, hf in unet_conversion_map_layer: + if "resnets" in hf: + for sd_res, hf_res in unet_conversion_map_resnet: + unet_conversion_map.append((sd + sd_res, hf + hf_res)) + else: + unet_conversion_map.append((sd, hf)) + + for j in range(2): + hf_time_embed_prefix = f"time_embedding.linear_{j + 1}." + sd_time_embed_prefix = f"time_embed.{j * 2}." + unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) + + for j in range(2): + hf_label_embed_prefix = f"add_embedding.linear_{j + 1}." + sd_label_embed_prefix = f"label_emb.0.{j * 2}." + unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) + + unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) + unet_conversion_map.append(("out.0.", "conv_norm_out.")) + unet_conversion_map.append(("out.2.", "conv_out.")) + + sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map} + return sd_hf_conversion_map + + +class KeyConvert: + def __init__(self): + self.is_sdxl = True if shared.sd_model_type == "sdxl" else False + self.UNET_CONVERSION_MAP = make_unet_conversion_map() if self.is_sdxl else None + self.LORA_PREFIX_UNET = "lora_unet_" + self.LORA_PREFIX_TEXT_ENCODER = "lora_te_" + self.OFT_PREFIX_UNET = "oft_unet_" + # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER + self.LORA_PREFIX_TEXT_ENCODER1 = "lora_te1_" + self.LORA_PREFIX_TEXT_ENCODER2 = "lora_te2_" + + def __call__(self, key): + if self.is_sdxl: + if "diffusion_model" in key: # Fix NTC Slider naming error + key = key.replace("diffusion_model", "lora_unet") + map_keys = list(self.UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules + map_keys.sort() + search_key = key.replace(self.LORA_PREFIX_UNET, "").replace(self.OFT_PREFIX_UNET, "").replace(self.LORA_PREFIX_TEXT_ENCODER1, "").replace(self.LORA_PREFIX_TEXT_ENCODER2, "") + position = bisect.bisect_right(map_keys, search_key) + map_key = map_keys[position - 1] + if search_key.startswith(map_key): + key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key]).replace("oft", "lora") # pylint: disable=unsubscriptable-object + if "lycoris" in key and "transformer" in key: + key = key.replace("lycoris", "lora_transformer") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + if sd_module is None: + sd_module = shared.sd_model.network_layer_mapping.get(key.replace("guidance", "timestep"), None) # FLUX1 fix + if debug and sd_module is None: + raise RuntimeError(f"LoRA key not found in network_layer_mapping: key={key} 
mapping={shared.sd_model.network_layer_mapping.keys()}") + return key, sd_module + + +# Taken from https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_conversion_utils.py +# Modified from 'lora_A' and 'lora_B' to 'lora_down' and 'lora_up' +# Added early exit +# The utilities under `_convert_kohya_flux_lora_to_diffusers()` +# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py +# All credits go to `kohya-ss`. +def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + + # scale weight by alpha and dim + rank = down_weight.shape[0] + alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar + scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here + + # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2 + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + ait_sd[ait_key + ".lora_down.weight"] = down_weight * scale_down + ait_sd[ait_key + ".lora_up.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up + +def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + up_weight = sds_sd.pop(sds_key + ".lora_up.weight") + sd_lora_rank = down_weight.shape[0] + + # scale weight by alpha and dim + alpha = sds_sd.pop(sds_key + ".alpha") + scale = alpha / sd_lora_rank + + # calculate scale_down and scale_up + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + down_weight = down_weight * scale_down + up_weight = up_weight * scale_up + + # calculate dims if not provided + num_splits = len(ait_keys) + if dims is None: + dims = [up_weight.shape[0] // num_splits] * num_splits + else: + assert sum(dims) == up_weight.shape[0] + + # check upweight is sparse or not + is_sparse = False + if sd_lora_rank % num_splits == 0: + ait_rank = sd_lora_rank // num_splits + is_sparse = True + i = 0 + for j in range(len(dims)): + for k in range(len(dims)): + if j == k: + continue + is_sparse = is_sparse and torch.all( + up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0 + ) + i += dims[j] + # if is_sparse: + # print(f"weight is sparse: {sds_key}") + + # make ai-toolkit weight + ait_down_keys = [k + ".lora_down.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_up.weight" for k in ait_keys] + if not is_sparse: + # down_weight is copied to each split + ait_sd.update({k: down_weight for k in ait_down_keys}) + + # up_weight is split to each split + ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 # pylint: disable=unnecessary-comprehension + else: + # down_weight is chunked to each split + ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416 # pylint: disable=unnecessary-comprehension + + # up_weight is sparse: only non-zero values are copied to each split + i = 0 + for j in range(len(dims)): + ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous() + i += dims[j] + +def _convert_text_encoder_lora_key(key, lora_name): + """ + Converts a text encoder LoRA key to a 
Diffusers compatible key. + """ + if lora_name.startswith(("lora_te_", "lora_te1_")): + key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_" + else: + key_to_replace = "lora_te2_" + + diffusers_name = key.replace(key_to_replace, "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + diffusers_name = diffusers_name.replace("text.projection", "text_projection") + + if "self_attn" in diffusers_name or "text_projection" in diffusers_name: + pass + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it yet. + diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + return diffusers_name + +def _convert_kohya_flux_lora_to_diffusers(state_dict): + def _convert_sd_scripts_to_ai_toolkit(sds_sd): + ait_sd = {} + for i in range(19): + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_proj", + f"transformer.transformer_blocks.{i}.attn.to_out.0", + ) + _convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.to_q", + f"transformer.transformer_blocks.{i}.attn.to_k", + f"transformer.transformer_blocks.{i}.attn.to_v", + ], + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_mlp_0", + f"transformer.transformer_blocks.{i}.ff.net.0.proj", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_mlp_2", + f"transformer.transformer_blocks.{i}.ff.net.2", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_mod_lin", + f"transformer.transformer_blocks.{i}.norm1.linear", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_proj", + f"transformer.transformer_blocks.{i}.attn.to_add_out", + ) + _convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.add_q_proj", + f"transformer.transformer_blocks.{i}.attn.add_k_proj", + f"transformer.transformer_blocks.{i}.attn.add_v_proj", + ], + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_mlp_0", + f"transformer.transformer_blocks.{i}.ff_context.net.0.proj", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_mlp_2", + f"transformer.transformer_blocks.{i}.ff_context.net.2", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_mod_lin", + f"transformer.transformer_blocks.{i}.norm1_context.linear", + ) + + for i in range(38): + _convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear1", + [ + f"transformer.single_transformer_blocks.{i}.attn.to_q", + f"transformer.single_transformer_blocks.{i}.attn.to_k", + f"transformer.single_transformer_blocks.{i}.attn.to_v", + f"transformer.single_transformer_blocks.{i}.proj_mlp", + ], + dims=[3072, 3072, 3072, 12288], + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear2", + f"transformer.single_transformer_blocks.{i}.proj_out", + ) 
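+            # e.g. "lora_unet_single_blocks_{i}_linear1" above is split into to_q/to_k/to_v/proj_mlp with widths 3072/3072/3072/12288,
+            # while the per-block modulation weights below map onto diffusers' norm.linear for the same single block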
+ _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_modulation_lin", + f"transformer.single_transformer_blocks.{i}.norm.linear", + ) + + if len(sds_sd) > 0: + return None + + return ait_sd + + return _convert_sd_scripts_to_ai_toolkit(state_dict) + +def _convert_kohya_sd3_lora_to_diffusers(state_dict): + def _convert_sd_scripts_to_ai_toolkit(sds_sd): + ait_sd = {} + for i in range(38): + _convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_joint_blocks_{i}_context_block_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.to_q", + f"transformer.transformer_blocks.{i}.attn.to_k", + f"transformer.transformer_blocks.{i}.attn.to_v", + ], + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_joint_blocks_{i}_context_block_mlp_fc1", + f"transformer.transformer_blocks.{i}.ff_context.net.0.proj", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_joint_blocks_{i}_context_block_mlp_fc2", + f"transformer.transformer_blocks.{i}.ff_context.net.2", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_joint_blocks_{i}_x_block_mlp_fc1", + f"transformer.transformer_blocks.{i}.ff.net.0.proj", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_joint_blocks_{i}_x_block_mlp_fc2", + f"transformer.transformer_blocks.{i}.ff.net.2", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_joint_blocks_{i}_context_block_adaLN_modulation_1", + f"transformer.transformer_blocks.{i}.norm1_context.linear", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_joint_blocks_{i}_x_block_adaLN_modulation_1", + f"transformer.transformer_blocks.{i}.norm1.linear", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_joint_blocks_{i}_context_block_attn_proj", + f"transformer.transformer_blocks.{i}.attn.to_add_out", + ) + _convert_to_ai_toolkit( + sds_sd, + ait_sd, + f"lora_unet_joint_blocks_{i}_x_block_attn_proj", + f"transformer.transformer_blocks.{i}.attn.to_out_0", + ) + + _convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_joint_blocks_{i}_x_block_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.add_q_proj", + f"transformer.transformer_blocks.{i}.attn.add_k_proj", + f"transformer.transformer_blocks.{i}.attn.add_v_proj", + ], + ) + remaining_keys = list(sds_sd.keys()) + te_state_dict = {} + if remaining_keys: + if not all(k.startswith("lora_te1") for k in remaining_keys): + raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}") + for key in remaining_keys: + if not key.endswith("lora_down.weight"): + continue + + lora_name = key.split(".")[0] + lora_name_up = f"{lora_name}.lora_up.weight" + lora_name_alpha = f"{lora_name}.alpha" + diffusers_name = _convert_text_encoder_lora_key(key, lora_name) + + sd_lora_rank = 1 + if lora_name.startswith(("lora_te_", "lora_te1_")): + down_weight = sds_sd.pop(key) + sd_lora_rank = down_weight.shape[0] + te_state_dict[diffusers_name] = down_weight + te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up) + + if lora_name_alpha in sds_sd: + alpha = sds_sd.pop(lora_name_alpha).item() + scale = alpha / sd_lora_rank + + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + te_state_dict[diffusers_name] *= scale_down + te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up + + if len(sds_sd) > 0: + print(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}") + + if te_state_dict: + te_state_dict = {f"text_encoder.{module_name}": params for 
module_name, params in te_state_dict.items()} + + new_state_dict = {**ait_sd, **te_state_dict} + return new_state_dict + + return _convert_sd_scripts_to_ai_toolkit(state_dict) + + +def assign_network_names_to_compvis_modules(sd_model): + if sd_model is None: + return + sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatiblility + network_layer_mapping = {} + if hasattr(sd_model, 'text_encoder') and sd_model.text_encoder is not None: + for name, module in sd_model.text_encoder.named_modules(): + prefix = "lora_te1_" if hasattr(sd_model, 'text_encoder_2') else "lora_te_" + network_name = prefix + name.replace(".", "_") + network_layer_mapping[network_name] = module + module.network_layer_name = network_name + if hasattr(sd_model, 'text_encoder_2'): + for name, module in sd_model.text_encoder_2.named_modules(): + network_name = "lora_te2_" + name.replace(".", "_") + network_layer_mapping[network_name] = module + module.network_layer_name = network_name + if hasattr(sd_model, 'unet'): + for name, module in sd_model.unet.named_modules(): + network_name = "lora_unet_" + name.replace(".", "_") + network_layer_mapping[network_name] = module + module.network_layer_name = network_name + if hasattr(sd_model, 'transformer'): + for name, module in sd_model.transformer.named_modules(): + network_name = "lora_transformer_" + name.replace(".", "_") + network_layer_mapping[network_name] = module + if "norm" in network_name and "linear" not in network_name and shared.sd_model_type != "sd3": + continue + module.network_layer_name = network_name + shared.sd_model.network_layer_mapping = network_layer_mapping diff --git a/modules/lora/lora_extract.py b/modules/lora/lora_extract.py new file mode 100644 index 000000000..58cd065bb --- /dev/null +++ b/modules/lora/lora_extract.py @@ -0,0 +1,271 @@ +import os +import time +import json +import datetime +import torch +from safetensors.torch import save_file +import gradio as gr +from rich import progress as p +from modules import shared, devices +from modules.ui_common import create_refresh_button +from modules.call_queue import wrap_gradio_gpu_call + + +class SVDHandler: + def __init__(self, maxrank=0, rank_ratio=1): + self.network_name: str = None + self.U: torch.Tensor = None + self.S: torch.Tensor = None + self.Vh: torch.Tensor = None + self.maxrank: int = maxrank + self.rank_ratio: float = rank_ratio + self.rank: int = 0 + self.out_size: int = None + self.in_size: int = None + self.kernel_size: tuple[int, int] = None + self.conv2d: bool = False + + def decompose(self, weight, backupweight): + self.conv2d = len(weight.size()) == 4 + self.kernel_size = None if not self.conv2d else weight.size()[2:4] + self.out_size, self.in_size = weight.size()[0:2] + diffweight = weight.clone().to(devices.device) + diffweight -= backupweight.to(devices.device) + if self.conv2d: + if self.conv2d and self.kernel_size != (1, 1): + diffweight = diffweight.flatten(start_dim=1) + else: + diffweight = diffweight.squeeze() + self.U, self.S, self.Vh = torch.svd_lowrank(diffweight.to(device=devices.device, dtype=torch.float), self.maxrank, 2) + # del diffweight + self.U = self.U.to(device=devices.cpu, dtype=torch.bfloat16) + self.S = self.S.to(device=devices.cpu, dtype=torch.bfloat16) + self.Vh = self.Vh.t().to(device=devices.cpu, dtype=torch.bfloat16) # svd_lowrank outputs a transposed matrix + + def findrank(self): + if self.rank_ratio < 1: + S_squared = self.S.pow(2) + S_fro_sq = float(torch.sum(S_squared)) + sum_S_squared = torch.cumsum(S_squared, dim=0) 
/ S_fro_sq + index = int(torch.searchsorted(sum_S_squared, self.rank_ratio ** 2)) + 1 + index = max(1, min(index, len(self.S) - 1)) + self.rank = index + if self.maxrank > 0: + self.rank = min(self.rank, self.maxrank) + else: + self.rank = min(self.in_size, self.out_size, self.maxrank) + + def makeweights(self): + self.findrank() + up = self.U[:, :self.rank] @ torch.diag(self.S[:self.rank]) + down = self.Vh[:self.rank, :] + if self.conv2d and self.kernel_size is not None: + up = up.reshape(self.out_size, self.rank, 1, 1) + down = down.reshape(self.rank, self.in_size, self.kernel_size[0], self.kernel_size[1]) # pylint: disable=unsubscriptable-object + return_dict = {f'{self.network_name}.lora_up.weight': up.contiguous(), + f'{self.network_name}.lora_down.weight': down.contiguous(), + f'{self.network_name}.alpha': torch.tensor(down.shape[0]), + } + return return_dict + + +def loaded_lora(): + if not shared.sd_loaded: + return "" + loaded = set() + if hasattr(shared.sd_model, 'unet'): + for _name, module in shared.sd_model.unet.named_modules(): + current = getattr(module, "network_current_names", None) + if current is not None: + current = [item[0] for item in current] + loaded.update(current) + return list(loaded) + + +def loaded_lora_str(): + return ", ".join(loaded_lora()) + + +def make_meta(fn, maxrank, rank_ratio): + meta = { + "model_spec.sai_model_spec": "1.0.0", + "model_spec.title": os.path.splitext(os.path.basename(fn))[0], + "model_spec.author": "SD.Next", + "model_spec.implementation": "https://github.com/vladmandic/automatic", + "model_spec.date": datetime.datetime.now().astimezone().replace(microsecond=0).isoformat(), + "model_spec.base_model": shared.opts.sd_model_checkpoint, + "model_spec.dtype": str(devices.dtype), + "model_spec.base_lora": json.dumps(loaded_lora()), + "model_spec.config": f"maxrank={maxrank} rank_ratio={rank_ratio}", + } + if shared.sd_model_type == "sdxl": + meta["model_spec.architecture"] = "stable-diffusion-xl-v1-base/lora" # sai standard + meta["ss_base_model_version"] = "sdxl_base_v1-0" # kohya standard + elif shared.sd_model_type == "sd": + meta["model_spec.architecture"] = "stable-diffusion-v1/lora" + meta["ss_base_model_version"] = "sd_v1" + elif shared.sd_model_type == "f1": + meta["model_spec.architecture"] = "flux-1-dev/lora" + meta["ss_base_model_version"] = "flux1" + elif shared.sd_model_type == "sc": + meta["model_spec.architecture"] = "stable-cascade-v1-prior/lora" + return meta + + +def make_lora(fn, maxrank, auto_rank, rank_ratio, modules, overwrite): + if not shared.sd_loaded or not shared.native: + msg = "LoRA extract: model not loaded" + shared.log.warning(msg) + yield msg + return + if loaded_lora() == "": + msg = "LoRA extract: no LoRA detected" + shared.log.warning(msg) + yield msg + return + if not fn: + msg = "LoRA extract: target filename required" + shared.log.warning(msg) + yield msg + return + t0 = time.time() + maxrank = int(maxrank) + rank_ratio = 1 if not auto_rank else rank_ratio + shared.log.debug(f'LoRA extract: modules={modules} maxrank={maxrank} auto={auto_rank} ratio={rank_ratio} fn="{fn}"') + shared.state.begin('LoRA extract') + + with p.Progress(p.TextColumn('[cyan]LoRA extract'), p.BarColumn(), p.TaskProgressColumn(), p.TimeRemainingColumn(), p.TimeElapsedColumn(), p.TextColumn('[cyan]{task.description}'), console=shared.console) as progress: + + if 'te' in modules and getattr(shared.sd_model, 'text_encoder', None) is not None: + modules = shared.sd_model.text_encoder.named_modules() + task = 
progress.add_task(description="te1 decompose", total=len(list(modules))) + for name, module in shared.sd_model.text_encoder.named_modules(): + progress.update(task, advance=1) + weights_backup = getattr(module, "network_weights_backup", None) + if weights_backup is None or getattr(module, "network_current_names", None) is None: + continue + prefix = "lora_te1_" if hasattr(shared.sd_model, 'text_encoder_2') else "lora_te_" + module.svdhandler = SVDHandler(maxrank, rank_ratio) + module.svdhandler.network_name = prefix + name.replace(".", "_") + with devices.inference_context(): + module.svdhandler.decompose(module.weight, weights_backup) + progress.remove_task(task) + t1 = time.time() + + if 'te' in modules and getattr(shared.sd_model, 'text_encoder_2', None) is not None: + modules = shared.sd_model.text_encoder_2.named_modules() + task = progress.add_task(description="te2 decompose", total=len(list(modules))) + for name, module in shared.sd_model.text_encoder_2.named_modules(): + progress.update(task, advance=1) + weights_backup = getattr(module, "network_weights_backup", None) + if weights_backup is None or getattr(module, "network_current_names", None) is None: + continue + module.svdhandler = SVDHandler(maxrank, rank_ratio) + module.svdhandler.network_name = "lora_te2_" + name.replace(".", "_") + with devices.inference_context(): + module.svdhandler.decompose(module.weight, weights_backup) + progress.remove_task(task) + t2 = time.time() + + if 'unet' in modules and getattr(shared.sd_model, 'unet', None) is not None: + modules = shared.sd_model.unet.named_modules() + task = progress.add_task(description="unet decompose", total=len(list(modules))) + for name, module in shared.sd_model.unet.named_modules(): + progress.update(task, advance=1) + weights_backup = getattr(module, "network_weights_backup", None) + if weights_backup is None or getattr(module, "network_current_names", None) is None: + continue + module.svdhandler = SVDHandler(maxrank, rank_ratio) + module.svdhandler.network_name = "lora_unet_" + name.replace(".", "_") + with devices.inference_context(): + module.svdhandler.decompose(module.weight, weights_backup) + progress.remove_task(task) + t3 = time.time() + + # TODO: lora make support quantized flux + # if 'te' in modules and getattr(shared.sd_model, 'transformer', None) is not None: + # for name, module in shared.sd_model.transformer.named_modules(): + # if "norm" in name and "linear" not in name: + # continue + # weights_backup = getattr(module, "network_weights_backup", None) + # if weights_backup is None: + # continue + # module.svdhandler = SVDHandler() + # module.svdhandler.network_name = "lora_transformer_" + name.replace(".", "_") + # module.svdhandler.decompose(module.weight, weights_backup) + # module.svdhandler.findrank(rank, rank_ratio) + + lora_state_dict = {} + for sub in ['text_encoder', 'text_encoder_2', 'unet', 'transformer']: + submodel = getattr(shared.sd_model, sub, None) + if submodel is not None: + modules = submodel.named_modules() + task = progress.add_task(description=f"{sub} exctract", total=len(list(modules))) + for _name, module in submodel.named_modules(): + progress.update(task, advance=1) + if not hasattr(module, "svdhandler"): + continue + lora_state_dict.update(module.svdhandler.makeweights()) + del module.svdhandler + progress.remove_task(task) + t4 = time.time() + + if not os.path.isabs(fn): + fn = os.path.join(shared.cmd_opts.lora_dir, fn) + if not fn.endswith('.safetensors'): + fn += '.safetensors' + if os.path.exists(fn): + if overwrite: 
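+                # overwrite requested: drop the existing file before save_file() writes the new safetensors below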
+ os.remove(fn) + else: + msg = f'LoRA extract: fn="{fn}" file exists' + shared.log.warning(msg) + yield msg + return + + shared.state.end() + meta = make_meta(fn, maxrank, rank_ratio) + shared.log.debug(f'LoRA metadata: {meta}') + try: + save_file(tensors=lora_state_dict, metadata=meta, filename=fn) + except Exception as e: + msg = f'LoRA extract error: fn="{fn}" {e}' + shared.log.error(msg) + yield msg + return + t5 = time.time() + shared.log.debug(f'LoRA extract: time={t5-t0:.2f} te1={t1-t0:.2f} te2={t2-t1:.2f} unet={t3-t2:.2f} save={t5-t4:.2f}') + keys = list(lora_state_dict.keys()) + msg = f'LoRA extract: fn="{fn}" keys={len(keys)}' + shared.log.info(msg) + yield msg + + +def create_ui(): + def gr_show(visible=True): + return {"visible": visible, "__type__": "update"} + + with gr.Tab(label="Extract LoRA"): + with gr.Row(): + loaded = gr.Textbox(placeholder="Press refresh to query loaded LoRA", label="Loaded LoRA", interactive=False) + create_refresh_button(loaded, lambda: None, lambda: {'value': loaded_lora_str()}, "testid") + with gr.Group(): + with gr.Row(): + modules = gr.CheckboxGroup(label="Modules to extract", value=['unet'], choices=['te', 'unet']) + with gr.Row(): + auto_rank = gr.Checkbox(value=False, label="Automatically determine rank") + rank_ratio = gr.Slider(label="Autorank ratio", value=1, minimum=0, maximum=1, step=0.05, visible=False) + rank = gr.Slider(label="Maximum rank", value=32, minimum=1, maximum=256) + with gr.Row(): + filename = gr.Textbox(label="LoRA target filename") + overwrite = gr.Checkbox(value=False, label="Overwrite existing file") + with gr.Row(): + extract = gr.Button(value="Extract LoRA", variant='primary') + status = gr.HTML(value="", show_label=False) + + auto_rank.change(fn=lambda x: gr_show(x), inputs=[auto_rank], outputs=[rank_ratio]) + extract.click( + fn=wrap_gradio_gpu_call(make_lora, extra_outputs=[]), + inputs=[filename, rank, auto_rank, rank_ratio, modules, overwrite], + outputs=[status] + ) diff --git a/modules/lora/lora_timers.py b/modules/lora/lora_timers.py new file mode 100644 index 000000000..30c35a728 --- /dev/null +++ b/modules/lora/lora_timers.py @@ -0,0 +1,38 @@ +class Timer(): + list: float = 0 + load: float = 0 + backup: float = 0 + calc: float = 0 + apply: float = 0 + move: float = 0 + restore: float = 0 + activate: float = 0 + deactivate: float = 0 + + @property + def total(self): + return round(self.activate + self.deactivate, 2) + + @property + def summary(self): + t = {} + for k, v in self.__dict__.items(): + if v > 0.1: + t[k] = round(v, 2) + return t + + def clear(self, complete: bool = False): + self.backup = 0 + self.calc = 0 + self.apply = 0 + self.move = 0 + self.restore = 0 + if complete: + self.activate = 0 + self.deactivate = 0 + + def add(self, name, t): + self.__dict__[name] += t + + def __str__(self): + return f'{self.__class__.__name__}({self.summary})' diff --git a/modules/lora/lyco_helpers.py b/modules/lora/lyco_helpers.py new file mode 100644 index 000000000..ac4f2419f --- /dev/null +++ b/modules/lora/lyco_helpers.py @@ -0,0 +1,66 @@ +import torch + + +def make_weight_cp(t, wa, wb): + temp = torch.einsum('i j k l, j r -> i r k l', t, wb) + return torch.einsum('i j k l, i r -> r j k l', temp, wa) + + +def rebuild_conventional(up, down, shape, dyn_dim=None): + up = up.reshape(up.size(0), -1) + down = down.reshape(down.size(0), -1) + if dyn_dim is not None: + up = up[:, :dyn_dim] + down = down[:dyn_dim, :] + return (up @ down).reshape(shape).to(up.dtype) + + +def rebuild_cp_decomposition(up, down, mid): + up 
= up.reshape(up.size(0), -1)
+    down = down.reshape(down.size(0), -1)
+    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down).to(up.dtype)
+
+
+# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
+def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
+    """
+    return a tuple of two value of input dimension decomposed by the number closest to factor
+    second value is higher or equal than first value.
+
+    In LoRA with Kroneckor Product, first value is a value for weight scale.
+    secon value is a value for weight.
+
+    Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
+
+    examples
+    factor
+        -1               2               4               8               16
+    127 -> 1, 127   127 -> 1, 127   127 -> 1, 127   127 -> 1, 127   127 -> 1, 127
+    128 -> 8, 16    128 -> 2, 64    128 -> 4, 32    128 -> 8, 16    128 -> 8, 16
+    250 -> 10, 25   250 -> 2, 125   250 -> 2, 125   250 -> 5, 50    250 -> 10, 25
+    360 -> 8, 45    360 -> 2, 180   360 -> 4, 90    360 -> 8, 45    360 -> 12, 30
+    512 -> 16, 32   512 -> 2, 256   512 -> 4, 128   512 -> 8, 64    512 -> 16, 32
+    1024 -> 32, 32  1024 -> 2, 512  1024 -> 4, 256  1024 -> 8, 128  1024 -> 16, 64
+    """
+
+    if factor > 0 and (dimension % factor) == 0:
+        m = factor
+        n = dimension // factor
+        if m > n:
+            n, m = m, n
+        return m, n
+    if factor < 0:
+        factor = dimension
+    m, n = 1, dimension
+    length = m + n
+    while m < n:
+        new_m = m + 1
+        while dimension % new_m != 0:
+            new_m += 1
+        new_n = dimension // new_m
+        if new_m + new_n > length or new_m > factor:
+            break
+        m, n = new_m, new_n
+    if m > n:
+        n, m = m, n
+    return m, n
diff --git a/modules/lora/network.py b/modules/lora/network.py
new file mode 100644
index 000000000..97feb76f1
--- /dev/null
+++ b/modules/lora/network.py
@@ -0,0 +1,188 @@
+import os
+import enum
+from typing import Union
+from collections import namedtuple
+from modules import sd_models, hashes, shared
+
+
+NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
+metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
+
+
+class SdVersion(enum.Enum):
+    Unknown = 1
+    SD1 = 2
+    SD2 = 3
+    SD3 = 3
+    SDXL = 4
+    SC = 5
+    F1 = 6
+
+
+class NetworkOnDisk:
+    def __init__(self, name, filename):
+        self.shorthash = None
+        self.hash = None
+        self.name = name
+        self.filename = filename
+        if filename.startswith(shared.cmd_opts.lora_dir):
+            self.fullname = os.path.splitext(filename[len(shared.cmd_opts.lora_dir):].strip("/"))[0]
+        else:
+            self.fullname = name
+        self.metadata = {}
+        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
+        if self.is_safetensors:
+            self.metadata = sd_models.read_metadata_from_safetensors(filename)
+        if self.metadata:
+            m = {}
+            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
+                m[k] = v
+            self.metadata = m
+        self.alias = self.metadata.get('ss_output_name', self.name)
+        sha256 = hashes.sha256_from_cache(self.filename, "lora/" + self.name) or hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=True) or self.metadata.get('sshs_model_hash')
+        self.set_hash(sha256)
+        self.sd_version = self.detect_version()
+
+    def detect_version(self):
+        base = str(self.metadata.get('ss_base_model_version', "")).lower()
+        arch = str(self.metadata.get('modelspec.architecture', "")).lower()
+        if base.startswith("sd_v1"):
+            return 'sd1'
+        if base.startswith("sdxl"):
+            return 'xl'
+        if base.startswith("stable_cascade"):
+            return 'sc'
+        if base.startswith("sd3"):
+            return 'sd3'
+        if base.startswith("flux"):
+            return 'f1'
+
+        if 
arch.startswith("stable-diffusion-v1"): + return 'sd1' + if arch.startswith("stable-diffusion-xl"): + return 'xl' + if arch.startswith("stable-cascade"): + return 'sc' + if arch.startswith("flux"): + return 'f1' + + if "v1-5" in str(self.metadata.get('ss_sd_model_name', "")): + return 'sd1' + if str(self.metadata.get('ss_v2', "")) == "True": + return 'sd2' + if 'flux' in self.name.lower(): + return 'f1' + if 'xl' in self.name.lower(): + return 'xl' + + return '' + + def set_hash(self, v): + self.hash = v or '' + self.shorthash = self.hash[0:8] + + def read_hash(self): + if not self.hash: + self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '') + + def get_alias(self): + import modules.lora.networks as networks + return self.name if shared.opts.lora_preferred_name == "filename" or self.alias.lower() in networks.forbidden_network_aliases else self.alias + + +class Network: # LoraModule + def __init__(self, name, network_on_disk: NetworkOnDisk): + self.name = name + self.network_on_disk = network_on_disk + self.te_multiplier = 1.0 + self.unet_multiplier = [1.0] * 3 + self.dyn_dim = None + self.modules = {} + self.bundle_embeddings = {} + self.mtime = None + self.mentioned_name = None + self.tags = None + """the text that was used to add the network to prompt - can be either name or an alias""" + + +class ModuleType: + def create_module(self, net: Network, weights: NetworkWeights) -> Union[Network, None]: # pylint: disable=W0613 + return None + + +class NetworkModule: + def __init__(self, net: Network, weights: NetworkWeights): + self.network = net + self.network_key = weights.network_key + self.sd_key = weights.sd_key + self.sd_module = weights.sd_module + if hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + self.dim = None + self.bias = weights.w.get("bias") + self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None + self.scale = weights.w["scale"].item() if "scale" in weights.w else None + self.dora_scale = weights.w.get("dora_scale", None) + self.dora_norm_dims = len(self.shape) - 1 + + def multiplier(self): + unet_multiplier = 3 * [self.network.unet_multiplier] if not isinstance(self.network.unet_multiplier, list) else self.network.unet_multiplier + if 'transformer' in self.sd_key[:20]: + return self.network.te_multiplier + if "down_blocks" in self.sd_key: + return unet_multiplier[0] + if "mid_block" in self.sd_key: + return unet_multiplier[1] + if "up_blocks" in self.sd_key: + return unet_multiplier[2] + else: + return unet_multiplier[0] + + def calc_scale(self): + if self.scale is not None: + return self.scale + if self.dim is not None and self.alpha is not None: + return self.alpha / self.dim + return 1.0 + + def apply_weight_decompose(self, updown, orig_weight): + # Match the device/dtype + orig_weight = orig_weight.to(updown.dtype) + dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype) + updown = updown.to(orig_weight.device) + + merged_scale1 = updown + orig_weight + merged_scale1_norm = ( + merged_scale1.transpose(0, 1) + .reshape(merged_scale1.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims) + .transpose(0, 1) + ) + + dora_merged = ( + merged_scale1 * (dora_scale / merged_scale1_norm) + ) + final_updown = dora_merged - orig_weight + return final_updown + + def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): + if self.bias is not None: + updown = updown.reshape(self.bias.shape) 
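+            # bias delta is applied in the bias' own shape, then the result is reshaped back to the layer output shape below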
+ updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) + updown = updown.reshape(output_shape) + if len(output_shape) == 4: + updown = updown.reshape(output_shape) + if orig_weight.size().numel() == updown.size().numel(): + updown = updown.reshape(orig_weight.shape) + if ex_bias is not None: + ex_bias = ex_bias * self.multiplier() + if self.dora_scale is not None: + updown = self.apply_weight_decompose(updown, orig_weight) + return updown * self.calc_scale() * self.multiplier(), ex_bias + + def calc_updown(self, target): + raise NotImplementedError + + def forward(self, x, y): + raise NotImplementedError diff --git a/modules/lora/network_full.py b/modules/lora/network_full.py new file mode 100644 index 000000000..5eb0b2e4e --- /dev/null +++ b/modules/lora/network_full.py @@ -0,0 +1,26 @@ +import modules.lora.network as network + + +class ModuleTypeFull(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["diff"]): + return NetworkModuleFull(net, weights) + return None + + +class NetworkModuleFull(network.NetworkModule): # pylint: disable=abstract-method + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.weight = weights.w.get("diff") + self.ex_bias = weights.w.get("diff_b") + + def calc_updown(self, target): + output_shape = self.weight.shape + updown = self.weight.to(target.device, dtype=target.dtype) + if self.ex_bias is not None: + ex_bias = self.ex_bias.to(target.device, dtype=target.dtype) + else: + ex_bias = None + + return self.finalize_updown(updown, target, output_shape, ex_bias) diff --git a/modules/lora/network_glora.py b/modules/lora/network_glora.py new file mode 100644 index 000000000..ffcb25986 --- /dev/null +++ b/modules/lora/network_glora.py @@ -0,0 +1,30 @@ +import modules.lora.network as network + + +class ModuleTypeGLora(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]): + return NetworkModuleGLora(net, weights) + return None + +# adapted from https://github.com/KohakuBlueleaf/LyCORIS +class NetworkModuleGLora(network.NetworkModule): # pylint: disable=abstract-method + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + if hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + + self.w1a = weights.w["a1.weight"] + self.w1b = weights.w["b1.weight"] + self.w2a = weights.w["a2.weight"] + self.w2b = weights.w["b2.weight"] + + def calc_updown(self, target): # pylint: disable=arguments-differ + w1a = self.w1a.to(target.device, dtype=target.dtype) + w1b = self.w1b.to(target.device, dtype=target.dtype) + w2a = self.w2a.to(target.device, dtype=target.dtype) + w2b = self.w2b.to(target.device, dtype=target.dtype) + output_shape = [w1a.size(0), w1b.size(1)] + updown = (w2b @ w1b) + ((target @ w2a) @ w1a) + return self.finalize_updown(updown, target, output_shape) diff --git a/modules/lora/network_hada.py b/modules/lora/network_hada.py new file mode 100644 index 000000000..6fc142b3b --- /dev/null +++ b/modules/lora/network_hada.py @@ -0,0 +1,46 @@ +import modules.lora.lyco_helpers as lyco_helpers +import modules.lora.network as network + + +class ModuleTypeHada(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for 
x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]): + return NetworkModuleHada(net, weights) + return None + + +class NetworkModuleHada(network.NetworkModule): # pylint: disable=abstract-method + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + if hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + self.w1a = weights.w["hada_w1_a"] + self.w1b = weights.w["hada_w1_b"] + self.dim = self.w1b.shape[0] + self.w2a = weights.w["hada_w2_a"] + self.w2b = weights.w["hada_w2_b"] + self.t1 = weights.w.get("hada_t1") + self.t2 = weights.w.get("hada_t2") + + def calc_updown(self, target): + w1a = self.w1a.to(target.device, dtype=target.dtype) + w1b = self.w1b.to(target.device, dtype=target.dtype) + w2a = self.w2a.to(target.device, dtype=target.dtype) + w2b = self.w2b.to(target.device, dtype=target.dtype) + output_shape = [w1a.size(0), w1b.size(1)] + if self.t1 is not None: + output_shape = [w1a.size(1), w1b.size(1)] + t1 = self.t1.to(target.device, dtype=target.dtype) + updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b) + output_shape += t1.shape[2:] + else: + if len(w1b.shape) == 4: + output_shape += w1b.shape[2:] + updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape) + if self.t2 is not None: + t2 = self.t2.to(target.device, dtype=target.dtype) + updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) + else: + updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape) + updown = updown1 * updown2 + return self.finalize_updown(updown, target, output_shape) diff --git a/modules/lora/network_ia3.py b/modules/lora/network_ia3.py new file mode 100644 index 000000000..479e42526 --- /dev/null +++ b/modules/lora/network_ia3.py @@ -0,0 +1,24 @@ +import modules.lora.network as network + +class ModuleTypeIa3(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["weight"]): + return NetworkModuleIa3(net, weights) + return None + + +class NetworkModuleIa3(network.NetworkModule): # pylint: disable=abstract-method + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + self.w = weights.w["weight"] + self.on_input = weights.w["on_input"].item() + + def calc_updown(self, target): + w = self.w.to(target.device, dtype=target.dtype) + output_shape = [w.size(0), target.size(1)] + if self.on_input: + output_shape.reverse() + else: + w = w.reshape(-1, 1) + updown = target * w + return self.finalize_updown(updown, target, output_shape) diff --git a/modules/lora/network_lokr.py b/modules/lora/network_lokr.py new file mode 100644 index 000000000..877d4005b --- /dev/null +++ b/modules/lora/network_lokr.py @@ -0,0 +1,57 @@ +import torch +import modules.lora.lyco_helpers as lyco_helpers +import modules.lora.network as network + + +class ModuleTypeLokr(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w) + has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w) + if has_1 and has_2: + return NetworkModuleLokr(net, weights) + return None + + +def make_kron(orig_shape, w1, w2): + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + w2 = w2.contiguous() + return torch.kron(w1, w2).reshape(orig_shape) + + +class NetworkModuleLokr(network.NetworkModule): # pylint: disable=abstract-method + def 
__init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + self.w1 = weights.w.get("lokr_w1") + self.w1a = weights.w.get("lokr_w1_a") + self.w1b = weights.w.get("lokr_w1_b") + self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim + self.w2 = weights.w.get("lokr_w2") + self.w2a = weights.w.get("lokr_w2_a") + self.w2b = weights.w.get("lokr_w2_b") + self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim + self.t2 = weights.w.get("lokr_t2") + + def calc_updown(self, target): + if self.w1 is not None: + w1 = self.w1.to(target.device, dtype=target.dtype) + else: + w1a = self.w1a.to(target.device, dtype=target.dtype) + w1b = self.w1b.to(target.device, dtype=target.dtype) + w1 = w1a @ w1b + if self.w2 is not None: + w2 = self.w2.to(target.device, dtype=target.dtype) + elif self.t2 is None: + w2a = self.w2a.to(target.device, dtype=target.dtype) + w2b = self.w2b.to(target.device, dtype=target.dtype) + w2 = w2a @ w2b + else: + t2 = self.t2.to(target.device, dtype=target.dtype) + w2a = self.w2a.to(target.device, dtype=target.dtype) + w2b = self.w2b.to(target.device, dtype=target.dtype) + w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) + output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)] + if len(target.shape) == 4: + output_shape = target.shape + updown = make_kron(output_shape, w1, w2) + return self.finalize_updown(updown, target, output_shape) diff --git a/modules/lora/network_lora.py b/modules/lora/network_lora.py new file mode 100644 index 000000000..3604e059d --- /dev/null +++ b/modules/lora/network_lora.py @@ -0,0 +1,75 @@ +import torch +import diffusers.models.lora as diffusers_lora +import modules.lora.lyco_helpers as lyco_helpers +import modules.lora.network as network +from modules import devices + + +class ModuleTypeLora(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]): + return NetworkModuleLora(net, weights) + return None + + +class NetworkModuleLora(network.NetworkModule): + + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + self.up_model = self.create_module(weights.w, "lora_up.weight") + self.down_model = self.create_module(weights.w, "lora_down.weight") + self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True) + self.dim = weights.w["lora_down.weight"].shape[0] + + def create_module(self, weights, key, none_ok=False): + weight = weights.get(key) + if weight is None and none_ok: + return None + linear_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear] + is_linear = type(self.sd_module) in linear_modules or self.sd_module.__class__.__name__ in {"NNCFLinear", "QLinear", "Linear4bit"} + is_conv = type(self.sd_module) in [torch.nn.Conv2d, diffusers_lora.LoRACompatibleConv] or self.sd_module.__class__.__name__ in {"NNCFConv2d", "QConv2d"} + if is_linear: + weight = weight.reshape(weight.shape[0], -1) + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif is_conv and (key == "lora_down.weight" or key == "dyn_up"): + if len(weight.shape) == 2: + weight = weight.reshape(weight.shape[0], -1, 1, 1) + if weight.shape[2] != 1 or weight.shape[3] != 1: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, 
self.sd_module.padding, bias=False) + else: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + elif is_conv and key == "lora_mid.weight": + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) + elif is_conv and (key == "lora_up.weight" or key == "dyn_down"): + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + else: + raise AssertionError(f'Lora unsupported: layer={self.network_key} type={type(self.sd_module).__name__}') + with torch.no_grad(): + if weight.shape != module.weight.shape: + weight = weight.reshape(module.weight.shape) + module.weight.copy_(weight) + module.weight.requires_grad_(False) + return module + + def calc_updown(self, target): # pylint: disable=W0237 + target_dtype = target.dtype if target.dtype != torch.uint8 else self.up_model.weight.dtype + up = self.up_model.weight.to(target.device, dtype=target_dtype) + down = self.down_model.weight.to(target.device, dtype=target_dtype) + output_shape = [up.size(0), down.size(1)] + if self.mid_model is not None: + # cp-decomposition + mid = self.mid_model.weight.to(target.device, dtype=target_dtype) + updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid) + output_shape += mid.shape[2:] + else: + if len(down.shape) == 4: + output_shape += down.shape[2:] + updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim) + return self.finalize_updown(updown, target, output_shape) + + def forward(self, x, y): + self.up_model.to(device=devices.device) + self.down_model.to(device=devices.device) + if hasattr(y, "scale"): + return y(scale=1) + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale() + return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale() diff --git a/modules/lora/network_norm.py b/modules/lora/network_norm.py new file mode 100644 index 000000000..5d059e92e --- /dev/null +++ b/modules/lora/network_norm.py @@ -0,0 +1,24 @@ +import modules.lora.network as network + + +class ModuleTypeNorm(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["w_norm", "b_norm"]): + return NetworkModuleNorm(net, weights) + return None + + +class NetworkModuleNorm(network.NetworkModule): # pylint: disable=abstract-method + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + self.w_norm = weights.w.get("w_norm") + self.b_norm = weights.w.get("b_norm") + + def calc_updown(self, target): + output_shape = self.w_norm.shape + updown = self.w_norm.to(target.device, dtype=target.dtype) + if self.b_norm is not None: + ex_bias = self.b_norm.to(target.device, dtype=target.dtype) + else: + ex_bias = None + return self.finalize_updown(updown, target, output_shape, ex_bias) diff --git a/modules/lora/network_oft.py b/modules/lora/network_oft.py new file mode 100644 index 000000000..e2e61ad45 --- /dev/null +++ b/modules/lora/network_oft.py @@ -0,0 +1,82 @@ +import torch +from einops import rearrange +import modules.lora.network as network +from modules.lora.lyco_helpers import factorization + + +class ModuleTypeOFT(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]): + return NetworkModuleOFT(net, weights) + return None + + +# Supports both 
kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py +class NetworkModuleOFT(network.NetworkModule): # pylint: disable=abstract-method + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + self.lin_module = None + self.org_module: list[torch.Module] = [self.sd_module] + self.scale = 1.0 + + # kohya-ss + if "oft_blocks" in weights.w.keys(): + self.is_kohya = True + self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size) + self.alpha = weights.w["alpha"] # alpha is constraint + self.dim = self.oft_blocks.shape[0] # lora dim + # LyCORIS + elif "oft_diag" in weights.w.keys(): + self.is_kohya = False + self.oft_blocks = weights.w["oft_diag"] + # self.alpha is unused + self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) + + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported + + if is_linear: + self.out_dim = self.sd_module.out_features + elif is_conv: + self.out_dim = self.sd_module.out_channels + elif is_other_linear: + self.out_dim = self.sd_module.embed_dim + + if self.is_kohya: + self.constraint = self.alpha * self.out_dim + self.num_blocks = self.dim + self.block_size = self.out_dim // self.dim + else: + self.constraint = None + self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) + + def calc_updown(self, target): + oft_blocks = self.oft_blocks.to(target.device, dtype=target.dtype) + eye = torch.eye(self.block_size, device=target.device) + constraint = self.constraint.to(target.device) + + if self.is_kohya: + block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + norm_Q = torch.norm(block_Q.flatten()).to(target.device) + new_norm_Q = torch.clamp(norm_Q, max=constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + mat1 = eye + block_Q + mat2 = (eye - block_Q).float().inverse() + oft_blocks = torch.matmul(mat1, mat2) + + R = oft_blocks.to(target.device, dtype=target.dtype) + + # This errors out for MultiheadAttention, might need to be handled up-stream + merged_weight = rearrange(target, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + merged_weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R, + merged_weight + ) + merged_weight = rearrange(merged_weight, 'k m ... 
-> (k m) ...') + + updown = merged_weight.to(target.device, dtype=target.dtype) - target + output_shape = target.shape + return self.finalize_updown(updown, target, output_shape) diff --git a/modules/lora/network_overrides.py b/modules/lora/network_overrides.py new file mode 100644 index 000000000..5334f3c1b --- /dev/null +++ b/modules/lora/network_overrides.py @@ -0,0 +1,49 @@ +from modules import shared + + +maybe_diffusers = [ # forced if lora_maybe_diffusers is enabled + 'aaebf6360f7d', # sd15-lcm + '3d18b05e4f56', # sdxl-lcm + 'b71dcb732467', # sdxl-tcd + '813ea5fb1c67', # sdxl-turbo + # not really needed, but just in case + '5a48ac366664', # hyper-sd15-1step + 'ee0ff23dcc42', # hyper-sd15-2step + 'e476eb1da5df', # hyper-sd15-4step + 'ecb844c3f3b0', # hyper-sd15-8step + '1ab289133ebb', # hyper-sd15-8step-cfg + '4f494295edb1', # hyper-sdxl-8step + 'ca14a8c621f8', # hyper-sdxl-8step-cfg + '1c88f7295856', # hyper-sdxl-4step + 'fdd5dcd1d88a', # hyper-sdxl-2step + '8cca3706050b', # hyper-sdxl-1step +] + +force_diffusers = [ # forced always + '816d0eed49fd', # flash-sdxl + 'c2ec22757b46', # flash-sd15 +] + +force_models = [ # forced always + 'sc', + # 'sd3', + 'kandinsky', + 'hunyuandit', + 'auraflow', +] + +force_classes = [ # forced always +] + + +def check_override(shorthash=''): + force = False + force = force or (shared.sd_model_type in force_models) + force = force or (shared.sd_model.__class__.__name__ in force_classes) + if len(shorthash) < 4: + return force + force = force or (any(x.startswith(shorthash) for x in maybe_diffusers) if shared.opts.lora_maybe_diffusers else False) + force = force or any(x.startswith(shorthash) for x in force_diffusers) + if force and shared.opts.lora_maybe_diffusers: + shared.log.debug('LoRA override: force diffusers') + return force diff --git a/modules/lora/networks.py b/modules/lora/networks.py new file mode 100644 index 000000000..fc90ddd2d --- /dev/null +++ b/modules/lora/networks.py @@ -0,0 +1,601 @@ +from typing import Union, List +from contextlib import nullcontext +import os +import re +import time +import concurrent +import torch +import diffusers.models.lora +import rich.progress as rp + +from modules.lora import lora_timers, network, lora_convert, network_overrides +from modules.lora import network_lora, network_hada, network_ia3, network_oft, network_lokr, network_full, network_norm, network_glora +from modules.lora.extra_networks_lora import ExtraNetworkLora +from modules import shared, devices, sd_models, sd_models_compile, errors, files_cache, model_quant + + +debug = os.environ.get('SD_LORA_DEBUG', None) is not None +extra_network_lora = ExtraNetworkLora() +available_networks = {} +available_network_aliases = {} +loaded_networks: List[network.Network] = [] +applied_layers: list[str] = [] +bnb = None +lora_cache = {} +diffuser_loaded = [] +diffuser_scales = [] +available_network_hash_lookup = {} +forbidden_network_aliases = {} +re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)") +timer = lora_timers.Timer() +module_types = [ + network_lora.ModuleTypeLora(), + network_hada.ModuleTypeHada(), + network_ia3.ModuleTypeIa3(), + network_oft.ModuleTypeOFT(), + network_lokr.ModuleTypeLokr(), + network_full.ModuleTypeFull(), + network_norm.ModuleTypeNorm(), + network_glora.ModuleTypeGLora(), +] + +# section: load networks from disk + +def load_diffusers(name, network_on_disk, lora_scale=shared.opts.extra_networks_default_multiplier) -> Union[network.Network, None]: + t0 = time.time() + name = name.replace(".", "_") + shared.log.debug(f'Load 
network: type=LoRA name="{name}" file="{network_on_disk.filename}" detected={network_on_disk.sd_version} method=diffusers scale={lora_scale} fuse={shared.opts.lora_fuse_diffusers}') + if not shared.native: + return None + if not hasattr(shared.sd_model, 'load_lora_weights'): + shared.log.error(f'Load network: type=LoRA class={shared.sd_model.__class__} does not implement load lora') + return None + try: + shared.sd_model.load_lora_weights(network_on_disk.filename, adapter_name=name) + except Exception as e: + if 'already in use' in str(e): + pass + else: + if 'The following keys have not been correctly renamed' in str(e): + shared.log.error(f'Load network: type=LoRA name="{name}" diffusers unsupported format') + else: + shared.log.error(f'Load network: type=LoRA name="{name}" {e}') + if debug: + errors.display(e, "LoRA") + return None + if name not in diffuser_loaded: + diffuser_loaded.append(name) + diffuser_scales.append(lora_scale) + net = network.Network(name, network_on_disk) + net.mtime = os.path.getmtime(network_on_disk.filename) + timer.activate += time.time() - t0 + return net + + +def load_safetensors(name, network_on_disk) -> Union[network.Network, None]: + if not shared.sd_loaded: + return None + + cached = lora_cache.get(name, None) + if debug: + shared.log.debug(f'Load network: type=LoRA name="{name}" file="{network_on_disk.filename}" type=lora {"cached" if cached else ""}') + if cached is not None: + return cached + net = network.Network(name, network_on_disk) + net.mtime = os.path.getmtime(network_on_disk.filename) + sd = sd_models.read_state_dict(network_on_disk.filename, what='network') + if shared.sd_model_type == 'f1': # if kohya flux lora, convert state_dict + sd = lora_convert._convert_kohya_flux_lora_to_diffusers(sd) or sd # pylint: disable=protected-access + if shared.sd_model_type == 'sd3': # if kohya flux lora, convert state_dict + try: + sd = lora_convert._convert_kohya_sd3_lora_to_diffusers(sd) or sd # pylint: disable=protected-access + except ValueError: # EAFP for diffusers PEFT keys + pass + lora_convert.assign_network_names_to_compvis_modules(shared.sd_model) + keys_failed_to_match = {} + matched_networks = {} + bundle_embeddings = {} + convert = lora_convert.KeyConvert() + for key_network, weight in sd.items(): + parts = key_network.split('.') + if parts[0] == "bundle_emb": + emb_name, vec_name = parts[1], key_network.split(".", 2)[-1] + emb_dict = bundle_embeddings.get(emb_name, {}) + emb_dict[vec_name] = weight + bundle_embeddings[emb_name] = emb_dict + continue + if len(parts) > 5: # messy handler for diffusers peft lora + key_network_without_network_parts = '_'.join(parts[:-2]) + if not key_network_without_network_parts.startswith('lora_'): + key_network_without_network_parts = 'lora_' + key_network_without_network_parts + network_part = '.'.join(parts[-2:]).replace('lora_A', 'lora_down').replace('lora_B', 'lora_up') + else: + key_network_without_network_parts, network_part = key_network.split(".", 1) + key, sd_module = convert(key_network_without_network_parts) + if sd_module is None: + keys_failed_to_match[key_network] = key + continue + if key not in matched_networks: + matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module) + matched_networks[key].w[network_part] = weight + network_types = [] + for key, weights in matched_networks.items(): + net_module = None + for nettype in module_types: + net_module = nettype.create_module(net, weights) + if net_module is not None: + 
network_types.append(nettype.__class__.__name__) + break + if net_module is None: + shared.log.error(f'LoRA unhandled: name={name} key={key} weights={weights.w.keys()}') + else: + net.modules[key] = net_module + if len(keys_failed_to_match) > 0: + shared.log.warning(f'LoRA name="{name}" type={set(network_types)} unmatched={len(keys_failed_to_match)} matched={len(matched_networks)}') + if debug: + shared.log.debug(f'LoRA name="{name}" unmatched={keys_failed_to_match}') + else: + shared.log.debug(f'LoRA name="{name}" type={set(network_types)} keys={len(matched_networks)} direct={shared.opts.lora_fuse_diffusers}') + if len(matched_networks) == 0: + return None + lora_cache[name] = net + net.bundle_embeddings = bundle_embeddings + return net + + +def maybe_recompile_model(names, te_multipliers): + recompile_model = False + if shared.compiled_model_state is not None and shared.compiled_model_state.is_compiled: + if len(names) == len(shared.compiled_model_state.lora_model): + for i, name in enumerate(names): + if shared.compiled_model_state.lora_model[ + i] != f"{name}:{te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier}": + recompile_model = True + shared.compiled_model_state.lora_model = [] + break + if not recompile_model: + if len(loaded_networks) > 0 and debug: + shared.log.debug('Model Compile: Skipping LoRa loading') + return recompile_model + else: + recompile_model = True + shared.compiled_model_state.lora_model = [] + if recompile_model: + backup_cuda_compile = shared.opts.cuda_compile + sd_models.unload_model_weights(op='model') + shared.opts.cuda_compile = [] + sd_models.reload_model_weights(op='model') + shared.opts.cuda_compile = backup_cuda_compile + return recompile_model + + +def list_available_networks(): + t0 = time.time() + available_networks.clear() + available_network_aliases.clear() + forbidden_network_aliases.clear() + available_network_hash_lookup.clear() + forbidden_network_aliases.update({"none": 1, "Addams": 1}) + if not os.path.exists(shared.cmd_opts.lora_dir): + shared.log.warning(f'LoRA directory not found: path="{shared.cmd_opts.lora_dir}"') + + def add_network(filename): + if not os.path.isfile(filename): + return + name = os.path.splitext(os.path.basename(filename))[0] + name = name.replace('.', '_') + try: + entry = network.NetworkOnDisk(name, filename) + available_networks[entry.name] = entry + if entry.alias in available_network_aliases: + forbidden_network_aliases[entry.alias.lower()] = 1 + if shared.opts.lora_preferred_name == 'filename': + available_network_aliases[entry.name] = entry + else: + available_network_aliases[entry.alias] = entry + if entry.shorthash: + available_network_hash_lookup[entry.shorthash] = entry + except OSError as e: # should catch FileNotFoundError and PermissionError etc. 
+ shared.log.error(f'LoRA: filename="{filename}" {e}') + + candidates = list(files_cache.list_files(shared.cmd_opts.lora_dir, ext_filter=[".pt", ".ckpt", ".safetensors"])) + with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor: + for fn in candidates: + executor.submit(add_network, fn) + t1 = time.time() + timer.list = t1 - t0 + shared.log.info(f'Available LoRAs: path="{shared.cmd_opts.lora_dir}" items={len(available_networks)} folders={len(forbidden_network_aliases)} time={t1 - t0:.2f}') + + +def network_download(name): + from huggingface_hub import hf_hub_download + if os.path.exists(name): + return network.NetworkOnDisk(name, name) + parts = name.split('/') + if len(parts) >= 5 and parts[1] == 'huggingface.co': + repo_id = f'{parts[2]}/{parts[3]}' + filename = '/'.join(parts[4:]) + fn = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=shared.opts.hfcache_dir) + return network.NetworkOnDisk(name, fn) + return None + + +def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): + networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names] + if any(x is None for x in networks_on_disk): + list_available_networks() + networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names] + for i in range(len(names)): + if names[i].startswith('/'): + networks_on_disk[i] = network_download(names[i]) + failed_to_load_networks = [] + recompile_model = maybe_recompile_model(names, te_multipliers) + + loaded_networks.clear() + diffuser_loaded.clear() + diffuser_scales.clear() + t0 = time.time() + + for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)): + net = None + if network_on_disk is not None: + shorthash = getattr(network_on_disk, 'shorthash', '').lower() + if debug: + shared.log.debug(f'Load network: type=LoRA name="{name}" file="{network_on_disk.filename}" hash="{shorthash}"') + try: + if recompile_model: + shared.compiled_model_state.lora_model.append(f"{name}:{te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier}") + if shared.opts.lora_force_diffusers or network_overrides.check_override(shorthash): # OpenVINO only works with Diffusers LoRa loading + net = load_diffusers(name, network_on_disk, lora_scale=te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier) + else: + net = load_safetensors(name, network_on_disk) + if net is not None: + net.mentioned_name = name + network_on_disk.read_hash() + except Exception as e: + shared.log.error(f'Load network: type=LoRA file="{network_on_disk.filename}" {e}') + if debug: + errors.display(e, 'LoRA') + continue + if net is None: + failed_to_load_networks.append(name) + shared.log.error(f'Load network: type=LoRA name="{name}" detected={network_on_disk.sd_version if network_on_disk is not None else None} failed') + continue + shared.sd_model.embedding_db.load_diffusers_embedding(None, net.bundle_embeddings) + net.te_multiplier = te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier + net.unet_multiplier = unet_multipliers[i] if unet_multipliers else shared.opts.extra_networks_default_multiplier + net.dyn_dim = dyn_dims[i] if dyn_dims else shared.opts.extra_networks_default_multiplier + loaded_networks.append(net) + + while len(lora_cache) > shared.opts.lora_in_memory_limit: + name = next(iter(lora_cache)) + lora_cache.pop(name, None) + + if len(diffuser_loaded) > 0: + 
shared.log.debug(f'Load network: type=LoRA loaded={diffuser_loaded} available={shared.sd_model.get_list_adapters()} active={shared.sd_model.get_active_adapters()} scales={diffuser_scales}') + try: + t0 = time.time() + shared.sd_model.set_adapters(adapter_names=diffuser_loaded, adapter_weights=diffuser_scales) + if shared.opts.lora_fuse_diffusers: + shared.sd_model.fuse_lora(adapter_names=diffuser_loaded, lora_scale=1.0, fuse_unet=True, fuse_text_encoder=True) # fuse uses fixed scale since later apply does the scaling + shared.sd_model.unload_lora_weights() + timer.activate += time.time() - t0 + except Exception as e: + shared.log.error(f'Load network: type=LoRA {e}') + if debug: + errors.display(e, 'LoRA') + + if len(loaded_networks) > 0 and debug: + shared.log.debug(f'Load network: type=LoRA loaded={len(loaded_networks)} cache={list(lora_cache)}') + + if recompile_model: + shared.log.info("Load network: type=LoRA recompiling model") + backup_lora_model = shared.compiled_model_state.lora_model + if 'Model' in shared.opts.cuda_compile: + shared.sd_model = sd_models_compile.compile_diffusers(shared.sd_model) + + shared.compiled_model_state.lora_model = backup_lora_model + + if len(loaded_networks) > 0: + devices.torch_gc() + + timer.load = time.time() - t0 + + +# section: process loaded networks + +def network_backup_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str, wanted_names: tuple): + global bnb # pylint: disable=W0603 + backup_size = 0 + if len(loaded_networks) > 0 and network_layer_name is not None and any([net.modules.get(network_layer_name, None) for net in loaded_networks]): # noqa: C419 # pylint: disable=R1729 + t0 = time.time() + + weights_backup = getattr(self, "network_weights_backup", None) + if weights_backup is None and wanted_names != (): # pylint: disable=C1803 + weight = getattr(self, 'weight', None) + self.network_weights_backup = None + if getattr(weight, "quant_type", None) in ['nf4', 'fp4']: + if bnb is None: + bnb = model_quant.load_bnb('Load network: type=LoRA', silent=True) + if bnb is not None: + with devices.inference_context(): + if shared.opts.lora_fuse_diffusers: + self.network_weights_backup = True + else: + self.network_weights_backup = bnb.functional.dequantize_4bit(weight, quant_state=weight.quant_state, quant_type=weight.quant_type, blocksize=weight.blocksize,) + self.quant_state = weight.quant_state + self.quant_type = weight.quant_type + self.blocksize = weight.blocksize + else: + if shared.opts.lora_fuse_diffusers: + self.network_weights_backup = True + else: + weights_backup = weight.clone() + self.network_weights_backup = weights_backup.to(devices.cpu) + else: + if shared.opts.lora_fuse_diffusers: + self.network_weights_backup = True + else: + self.network_weights_backup = weight.clone().to(devices.cpu) + + bias_backup = getattr(self, "network_bias_backup", None) + if bias_backup is None: + if getattr(self, 'bias', None) is not None: + if shared.opts.lora_fuse_diffusers: + self.network_bias_backup = True + else: + bias_backup = self.bias.clone() + bias_backup = bias_backup.to(devices.cpu) + + if getattr(self, 'network_weights_backup', None) is not None: + backup_size += self.network_weights_backup.numel() * self.network_weights_backup.element_size() if isinstance(self.network_weights_backup, torch.Tensor) else 0 + if getattr(self, 'network_bias_backup', None) is not None: + backup_size += 
self.network_bias_backup.numel() * self.network_bias_backup.element_size() if isinstance(self.network_bias_backup, torch.Tensor) else 0 + timer.backup += time.time() - t0 + return backup_size + + +def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str): + if shared.opts.diffusers_offload_mode == "none": + try: + self.to(devices.device) + except Exception: + pass + batch_updown = None + batch_ex_bias = None + for net in loaded_networks: + module = net.modules.get(network_layer_name, None) + if module is None: + continue + try: + t0 = time.time() + try: + weight = self.weight.to(devices.device) + except Exception: + weight = self.weight + updown, ex_bias = module.calc_updown(weight) + if batch_updown is not None and updown is not None: + batch_updown += updown.to(batch_updown.device) + else: + batch_updown = updown + if batch_ex_bias is not None and ex_bias is not None: + batch_ex_bias += ex_bias.to(batch_ex_bias.device) + else: + batch_ex_bias = ex_bias + timer.calc += time.time() - t0 + if shared.opts.diffusers_offload_mode == "sequential": + t0 = time.time() + if batch_updown is not None: + batch_updown = batch_updown.to(devices.cpu) + if batch_ex_bias is not None: + batch_ex_bias = batch_ex_bias.to(devices.cpu) + t1 = time.time() + timer.move += t1 - t0 + except RuntimeError as e: + extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 + if debug: + module_name = net.modules.get(network_layer_name, None) + shared.log.error(f'LoRA apply weight name="{net.name}" module="{module_name}" layer="{network_layer_name}" {e}') + errors.display(e, 'LoRA') + raise RuntimeError('LoRA apply weight') from e + continue + return batch_updown, batch_ex_bias + + +def network_apply_direct(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, ex_bias: torch.Tensor, deactivate: bool = False): + weights_backup = getattr(self, "network_weights_backup", False) + bias_backup = getattr(self, "network_bias_backup", False) + if not weights_backup and not bias_backup: + return None, None + t0 = time.time() + + if weights_backup: + if updown is not None and len(self.weight.shape) == 4 and self.weight.shape[1] == 9: # inpainting model. 
zero pad updown to make channel[1] 4 to 9 + updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) # pylint: disable=not-callable + if updown is not None: + if deactivate: + updown *= -1 + if getattr(self, "quant_type", None) in ['nf4', 'fp4'] and bnb is not None: + try: # TODO lora load: direct with bnb + weight = bnb.functional.dequantize_4bit(self.weight, quant_state=self.quant_state, quant_type=self.quant_type, blocksize=self.blocksize) + new_weight = weight.to(devices.device) + updown.to(devices.device) + self.weight = bnb.nn.Params4bit(new_weight, quant_state=self.quant_state, quant_type=self.quant_type, blocksize=self.blocksize) + except Exception: + # shared.log.error(f'Load network: type=LoRA quant=bnb type={self.quant_type} state={self.quant_state} blocksize={self.blocksize} {e}') + extra_network_lora.errors['bnb'] = extra_network_lora.errors.get('bnb', 0) + 1 + new_weight = None + else: + try: + new_weight = self.weight.to(devices.device) + updown.to(devices.device) + except Exception: + new_weight = self.weight + updown + self.weight = torch.nn.Parameter(new_weight, requires_grad=False) + del new_weight + if hasattr(self, "qweight") and hasattr(self, "freeze"): + self.freeze() + + if bias_backup: + if ex_bias is not None: + if deactivate: + ex_bias *= -1 + new_weight = bias_backup.to(devices.device) + ex_bias.to(devices.device) + self.bias = torch.nn.Parameter(new_weight, requires_grad=False) + del new_weight + + timer.apply += time.time() - t0 + return self.weight.device, self.weight.dtype + + +def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, ex_bias: torch.Tensor, orig_device: torch.device, deactivate: bool = False): + weights_backup = getattr(self, "network_weights_backup", None) + bias_backup = getattr(self, "network_bias_backup", None) + if weights_backup is None and bias_backup is None: + return None, None + t0 = time.time() + + if weights_backup is not None: + self.weight = None + if updown is not None and len(weights_backup.shape) == 4 and weights_backup.shape[1] == 9: # inpainting model. 
zero pad updown to make channel[1] 4 to 9 + updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) # pylint: disable=not-callable + if updown is not None: + if deactivate: + updown *= -1 + new_weight = weights_backup.to(devices.device) + updown.to(devices.device) + if getattr(self, "quant_type", None) in ['nf4', 'fp4'] and bnb is not None: + self.weight = bnb.nn.Params4bit(new_weight, quant_state=self.quant_state, quant_type=self.quant_type, blocksize=self.blocksize) + else: + self.weight = torch.nn.Parameter(new_weight.to(device=orig_device), requires_grad=False) + del new_weight + else: + self.weight = torch.nn.Parameter(weights_backup.to(device=orig_device), requires_grad=False) + if hasattr(self, "qweight") and hasattr(self, "freeze"): + self.freeze() + + if bias_backup is not None: + self.bias = None + if ex_bias is not None: + if deactivate: + ex_bias *= -1 + new_weight = bias_backup.to(devices.device) + ex_bias.to(devices.device) + self.bias = torch.nn.Parameter(new_weight.to(device=orig_device), requires_grad=False) + del new_weight + else: + self.bias = torch.nn.Parameter(bias_backup.to(device=orig_device), requires_grad=False) + + timer.apply += time.time() - t0 + return self.weight.device, self.weight.dtype + + +def network_deactivate(): + if not shared.opts.lora_fuse_diffusers: + return + t0 = time.time() + timer.clear() + sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatibility + if shared.opts.diffusers_offload_mode == "sequential": + sd_models.disable_offload(sd_model) + sd_models.move_model(sd_model, device=devices.cpu) + modules = {} + for component_name in ['text_encoder', 'text_encoder_2', 'unet', 'transformer']: + component = getattr(sd_model, component_name, None) + if component is not None and hasattr(component, 'named_modules'): + modules[component_name] = list(component.named_modules()) + total = sum(len(x) for x in modules.values()) + if len(loaded_networks) > 0: + pbar = rp.Progress(rp.TextColumn('[cyan]Network: type=LoRA action=deactivate'), rp.BarColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console) + task = pbar.add_task(description='', total=total) + else: + task = None + pbar = nullcontext() + with devices.inference_context(), pbar: + applied_layers.clear() + weights_devices = [] + weights_dtypes = [] + for component in modules.keys(): + orig_device = getattr(sd_model, component, None).device + for _, module in modules[component]: + network_layer_name = getattr(module, 'network_layer_name', None) + if shared.state.interrupted or network_layer_name is None: + if task is not None: + pbar.update(task, advance=1) + continue + batch_updown, batch_ex_bias = network_calc_weights(module, network_layer_name) + if shared.opts.lora_fuse_diffusers: + weights_device, weights_dtype = network_apply_direct(module, batch_updown, batch_ex_bias, deactivate=True) + else: + weights_device, weights_dtype = network_apply_weights(module, batch_updown, batch_ex_bias, orig_device, deactivate=True) + weights_devices.append(weights_device) + weights_dtypes.append(weights_dtype) + if batch_updown is not None or batch_ex_bias is not None: + applied_layers.append(network_layer_name) + del batch_updown, batch_ex_bias + module.network_current_names = () + if task is not None: + pbar.update(task, advance=1, description=f'networks={len(loaded_networks)} modules={len(modules)} deactivate={len(applied_layers)}') + weights_devices, weights_dtypes = list(set([x 
for x in weights_devices if x is not None])), list(set([x for x in weights_dtypes if x is not None])) # noqa: C403 # pylint: disable=R1718 + timer.deactivate = time.time() - t0 + if debug and len(loaded_networks) > 0: + shared.log.debug(f'Deactivate network: type=LoRA networks={len(loaded_networks)} modules={total} deactivate={len(applied_layers)} device={weights_devices} dtype={weights_dtypes} fuse={shared.opts.lora_fuse_diffusers} time={timer.summary}') + modules.clear() + if shared.opts.diffusers_offload_mode == "sequential": + sd_models.set_diffuser_offload(sd_model, op="model") + + +def network_activate(include=[], exclude=[]): + t0 = time.time() + sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatibility + if shared.opts.diffusers_offload_mode == "sequential": + sd_models.disable_offload(sd_model) + sd_models.move_model(sd_model, device=devices.cpu) + modules = {} + components = include if len(include) > 0 else ['text_encoder', 'text_encoder_2', 'text_encoder_3', 'unet', 'transformer'] + components = [x for x in components if x not in exclude] + for name in components: + component = getattr(sd_model, name, None) + if component is not None and hasattr(component, 'named_modules'): + modules[name] = list(component.named_modules()) + total = sum(len(x) for x in modules.values()) + if len(loaded_networks) > 0: + pbar = rp.Progress(rp.TextColumn('[cyan]Network: type=LoRA action=activate'), rp.BarColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console) + task = pbar.add_task(description='', total=total) + else: + task = None + pbar = nullcontext() + with devices.inference_context(), pbar: + wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks) if len(loaded_networks) > 0 else () + applied_layers.clear() + backup_size = 0 + weights_devices = [] + weights_dtypes = [] + for component in modules.keys(): + orig_device = getattr(sd_model, component, None).device + for _, module in modules[component]: + network_layer_name = getattr(module, 'network_layer_name', None) + current_names = getattr(module, "network_current_names", ()) + if getattr(module, 'weight', None) is None or shared.state.interrupted or network_layer_name is None or current_names == wanted_names: + if task is not None: + pbar.update(task, advance=1) + continue + backup_size += network_backup_weights(module, network_layer_name, wanted_names) + batch_updown, batch_ex_bias = network_calc_weights(module, network_layer_name) + if shared.opts.lora_fuse_diffusers: + weights_device, weights_dtype = network_apply_direct(module, batch_updown, batch_ex_bias) + else: + weights_device, weights_dtype = network_apply_weights(module, batch_updown, batch_ex_bias, orig_device) + weights_devices.append(weights_device) + weights_dtypes.append(weights_dtype) + if batch_updown is not None or batch_ex_bias is not None: + applied_layers.append(network_layer_name) + del batch_updown, batch_ex_bias + module.network_current_names = wanted_names + if task is not None: + pbar.update(task, advance=1, description=f'networks={len(loaded_networks)} modules={total} apply={len(applied_layers)} backup={backup_size}') + if task is not None and len(applied_layers) == 0: + pbar.remove_task(task) # hide progress bar for no action + weights_devices, weights_dtypes = list(set([x for x in weights_devices if x is not None])), list(set([x for x in weights_dtypes if x is not None])) # noqa: C403 # 
pylint: disable=R1718 + timer.activate += time.time() - t0 + if debug and len(loaded_networks) > 0: + shared.log.debug(f'Load network: type=LoRA networks={len(loaded_networks)} components={components} modules={total} apply={len(applied_layers)} device={weights_devices} dtype={weights_dtypes} backup={backup_size} fuse={shared.opts.lora_fuse_diffusers} time={timer.summary}') + modules.clear() + if shared.opts.diffusers_offload_mode == "sequential": + sd_models.set_diffuser_offload(sd_model, op="model") diff --git a/modules/meissonic/pipeline_img2img.py b/modules/meissonic/pipeline_img2img.py index f26af123d..13e5c3717 100644 --- a/modules/meissonic/pipeline_img2img.py +++ b/modules/meissonic/pipeline_img2img.py @@ -56,9 +56,6 @@ class Img2ImgPipeline(DiffusionPipeline): model_cpu_offload_seq = "text_encoder->transformer->vqvae" - # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before - # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter - # off the meta device. There should be a way to fix this instead of just not offloading it _exclude_from_cpu_offload = ["vqvae"] def __init__( diff --git a/modules/meissonic/pipeline_inpaint.py b/modules/meissonic/pipeline_inpaint.py index 994846fba..d405afa53 100644 --- a/modules/meissonic/pipeline_inpaint.py +++ b/modules/meissonic/pipeline_inpaint.py @@ -53,9 +53,6 @@ class InpaintPipeline(DiffusionPipeline): model_cpu_offload_seq = "text_encoder->transformer->vqvae" - # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before - # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter - # off the meta device. There should be a way to fix this instead of just not offloading it _exclude_from_cpu_offload = ["vqvae"] def __init__( diff --git a/modules/meissonic/transformer.py b/modules/meissonic/transformer.py index 64f91baa2..43e77ddc7 100644 --- a/modules/meissonic/transformer.py +++ b/modules/meissonic/transformer.py @@ -341,11 +341,6 @@ def __call__( key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - # if image_rotary_emb is not None: # TODO broken import - # from .embeddings import apply_rotary_emb - # query = apply_rotary_emb(query, image_rotary_emb) - # key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) diff --git a/modules/memmon.py b/modules/memmon.py index 6887e1e1c..d9fa3963d 100644 --- a/modules/memmon.py +++ b/modules/memmon.py @@ -42,14 +42,14 @@ def read(self): if not self.disabled: try: self.data["free"], self.data["total"] = torch.cuda.mem_get_info(self.device.index if self.device.index is not None else torch.cuda.current_device()) + self.data["used"] = self.data["total"] - self.data["free"] torch_stats = torch.cuda.memory_stats(self.device) - self.data["active"] = torch_stats["active.all.current"] + self.data["active"] = torch_stats.get("active.all.current", torch_stats["active_bytes.all.current"]) self.data["active_peak"] = torch_stats["active_bytes.all.peak"] self.data["reserved"] = torch_stats["reserved_bytes.all.current"] self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"] - self.data['retries'] = torch_stats["num_alloc_retries"] - self.data['oom'] = 
torch_stats["num_ooms"] - self.data["used"] = self.data["total"] - self.data["free"] + self.data['retries'] = torch_stats.get("num_alloc_retries", -1) + self.data['oom'] = torch_stats.get("num_ooms", -1) except Exception: self.disabled = True return self.data diff --git a/modules/memstats.py b/modules/memstats.py index c417165a2..d43e5bbfa 100644 --- a/modules/memstats.py +++ b/modules/memstats.py @@ -4,13 +4,16 @@ from modules import shared, errors fail_once = False +mem = {} + + +def gb(val: float): + return round(val / 1024 / 1024 / 1024, 2) + def memory_stats(): global fail_once # pylint: disable=global-statement - def gb(val: float): - return round(val / 1024 / 1024 / 1024, 2) - - mem = {} + mem.clear() try: process = psutil.Process(os.getpid()) res = process.memory_info() @@ -19,10 +22,10 @@ def gb(val: float): mem.update({ 'ram': ram }) except Exception as e: if not fail_once: - shared.log.error('Memory stats: {e}') + shared.log.error(f'Memory stats: {e}') errors.display(e, 'Memory stats') fail_once = True - mem.update({ 'ram': str(e) }) + mem.update({ 'ram': { 'error': str(e) } }) try: s = torch.cuda.mem_get_info() gpu = { 'used': gb(s[1] - s[0]), 'total': gb(s[1]) } @@ -38,3 +41,18 @@ def gb(val: float): except Exception: pass return mem + + +def memory_cache(): + return mem + + +def ram_stats(): + try: + process = psutil.Process(os.getpid()) + res = process.memory_info() + ram_total = 100 * res.rss / process.memory_percent() + ram = { 'used': gb(res.rss), 'total': gb(ram_total) } + return ram + except Exception: + return { 'used': 0, 'total': 0 } diff --git a/modules/merging/merge_methods.py b/modules/merging/merge_methods.py index 3f704c20f..ce196b60c 100644 --- a/modules/merging/merge_methods.py +++ b/modules/merging/merge_methods.py @@ -4,7 +4,7 @@ import torch from torch import Tensor -__all__ = [ +__all__ = [ # noqa: RUF022 "weighted_sum", "weighted_subtraction", "tensor_sum", diff --git a/modules/model_flux.py b/modules/model_flux.py index 17234d9a4..f2286866e 100644 --- a/modules/model_flux.py +++ b/modules/model_flux.py @@ -34,7 +34,8 @@ def load_flux_quanto(checkpoint_info): with torch.device("meta"): transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype) quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu")) - transformer.eval() + if shared.opts.diffusers_eval: + transformer.eval() if transformer.dtype != devices.dtype: try: transformer = transformer.to(dtype=devices.dtype) @@ -61,7 +62,8 @@ def load_flux_quanto(checkpoint_info): with torch.device("meta"): text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype) quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu")) - text_encoder_2.eval() + if shared.opts.diffusers_eval: + text_encoder_2.eval() if text_encoder_2.dtype != devices.dtype: try: text_encoder_2 = text_encoder_2.to(dtype=devices.dtype) @@ -108,6 +110,7 @@ def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unu return transformer, text_encoder_2 +""" def quant_flux_bnb(checkpoint_info, transformer, text_encoder_2): repo_id = sd_models.path_to_repo(checkpoint_info.name) cache_dir=shared.opts.diffusers_dir @@ -137,18 +140,36 @@ def quant_flux_bnb(checkpoint_info, transformer, text_encoder_2): from modules import errors errors.display(e, 'FLUX:') return transformer, text_encoder_2 +""" +def load_quants(kwargs, repo_id, cache_dir): + if len(shared.opts.bnb_quantization) > 0: + 
quant_args = {} + quant_args = model_quant.create_bnb_config(quant_args) + quant_args = model_quant.create_ao_config(quant_args) + if not quant_args: + return kwargs + model_quant.load_bnb(f'Load model: type=FLUX quant={quant_args}') + if 'Model' in shared.opts.bnb_quantization and 'transformer' not in kwargs: + kwargs['transformer'] = diffusers.FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args) + shared.log.debug(f'Quantization: module=transformer type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}') + if 'Text Encoder' in shared.opts.bnb_quantization and 'text_encoder_3' not in kwargs: + kwargs['text_encoder_2'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args) + shared.log.debug(f'Quantization: module=t5 type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}') + return kwargs + + +""" def load_flux_gguf(file_path): transformer = None - model_te.install_gguf() + ggml.install_gguf() from accelerate import init_empty_weights from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers from modules import ggml, sd_hijack_accelerate with init_empty_weights(): - from diffusers import FluxTransformer2DModel - config = FluxTransformer2DModel.load_config(os.path.join('configs', 'flux'), subfolder="transformer") - transformer = FluxTransformer2DModel.from_config(config).to(devices.dtype) + config = diffusers.FluxTransformer2DModel.load_config(os.path.join('configs', 'flux'), subfolder="transformer") + transformer = diffusers.FluxTransformer2DModel.from_config(config).to(devices.dtype) expected_state_dict_keys = list(transformer.state_dict().keys()) state_dict, stats = ggml.load_gguf_state_dict(file_path, devices.dtype) state_dict = convert_flux_transformer_checkpoint_to_diffusers(state_dict) @@ -160,9 +181,11 @@ def load_flux_gguf(file_path): continue applied += 1 sd_hijack_accelerate.hijack_set_module_tensor_simple(transformer, tensor_name=param_name, value=param, device=0) + transformer.gguf = 'gguf' state_dict[param_name] = None shared.log.debug(f'Load model: type=Unet/Transformer applied={applied} skipped={skipped} stats={stats}') return transformer, None +""" def load_transformer(file_path): # triggered by opts.sd_unet change @@ -177,7 +200,9 @@ def load_transformer(file_path): # triggered by opts.sd_unet change } shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant={quant} dtype={devices.dtype}') if 'gguf' in file_path.lower(): - _transformer, _text_encoder_2 = load_flux_gguf(file_path) + # _transformer, _text_encoder_2 = load_flux_gguf(file_path) + from modules import ggml + _transformer = ggml.load_gguf(file_path, cls=diffusers.FluxTransformer2DModel, compute_dtype=devices.dtype) if _transformer is not None: transformer = _transformer elif quant == 'qint8' or quant == 'qint4': @@ -188,13 +213,14 @@ def load_transformer(file_path): # triggered by opts.sd_unet change _transformer, _text_encoder_2 = load_flux_bnb(file_path, diffusers_load_config) if _transformer is not None: transformer = _transformer - elif 'nf4' in quant: # TODO right now this is not working for civitai published nf4 models + elif 'nf4' in quant: # TODO flux: fix loader for civitai nf4 models from modules.model_flux_nf4 import load_flux_nf4 _transformer, 
_text_encoder_2 = load_flux_nf4(file_path) if _transformer is not None: transformer = _transformer else: diffusers_load_config = model_quant.create_bnb_config(diffusers_load_config) + diffusers_load_config = model_quant.create_ao_config(diffusers_load_config) transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **diffusers_load_config) if transformer is None: shared.log.error('Failed to load UNet model') @@ -223,10 +249,8 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch if shared.opts.sd_unet != 'None': try: debug(f'Load model: type=FLUX unet="{shared.opts.sd_unet}"') - _transformer = load_transformer(sd_unet.unet_dict[shared.opts.sd_unet]) - if _transformer is not None: - transformer = _transformer - else: + transformer = load_transformer(sd_unet.unet_dict[shared.opts.sd_unet]) + if transformer is None: shared.opts.sd_unet = 'None' sd_unet.failed_unet.append(shared.opts.sd_unet) except Exception as e: @@ -294,7 +318,6 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch # initialize pipeline with pre-loaded components kwargs = {} - # transformer, text_encoder_2 = quant_flux_bnb(checkpoint_info, transformer, text_encoder_2) if transformer is not None: kwargs['transformer'] = transformer sd_unet.loaded_unet = shared.opts.sd_unet @@ -306,26 +329,46 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch model_te.loaded_te = shared.opts.sd_text_encoder if vae is not None: kwargs['vae'] = vae - shared.log.debug(f'Load model: type=FLUX preloaded={list(kwargs)}') if repo_id == 'sayakpaul/flux.1-dev-nf4': repo_id = 'black-forest-labs/FLUX.1-dev' # workaround since sayakpaul model is missing model_index.json + if 'Fill' in repo_id: + cls = diffusers.FluxFillPipeline + elif 'Canny' in repo_id: + cls = diffusers.FluxControlPipeline + elif 'Depth' in repo_id: + cls = diffusers.FluxControlPipeline + else: + cls = diffusers.FluxPipeline + shared.log.debug(f'Load model: type=FLUX cls={cls.__name__} preloaded={list(kwargs)} revision={diffusers_load_config.get("revision", None)}') for c in kwargs: + if getattr(kwargs[c], 'quantization_method', None) is not None or getattr(kwargs[c], 'gguf', None) is not None: + shared.log.debug(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} quant={getattr(kwargs[c], "quantization_method", None) or getattr(kwargs[c], "gguf", None)}') if kwargs[c].dtype == torch.float32 and devices.dtype != torch.float32: - shared.log.warning(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype} recast') - kwargs[c] = kwargs[c].to(dtype=devices.dtype) + try: + kwargs[c] = kwargs[c].to(dtype=devices.dtype) + shared.log.warning(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype} recast') + except Exception: + pass - allow_bnb = 'gguf' not in (sd_unet.loaded_unet or '') - kwargs = model_quant.create_bnb_config(kwargs, allow_bnb) - if checkpoint_info.path.endswith('.safetensors') and os.path.isfile(checkpoint_info.path): - pipe = diffusers.FluxPipeline.from_single_file(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config) + allow_quant = 'gguf' not in (sd_unet.loaded_unet or '') + fn = checkpoint_info.path + if (fn is None) or (not os.path.exists(fn) or os.path.isdir(fn)): + # transformer, text_encoder_2 = quant_flux_bnb(checkpoint_info, transformer, text_encoder_2) + kwargs = load_quants(kwargs, repo_id, cache_dir=shared.opts.diffusers_dir) + kwargs = 
model_quant.create_bnb_config(kwargs, allow_quant) + kwargs = model_quant.create_ao_config(kwargs, allow_quant) + if fn.endswith('.safetensors') and os.path.isfile(fn): + pipe = diffusers.FluxPipeline.from_single_file(fn, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config) else: - pipe = diffusers.FluxPipeline.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config) + pipe = cls.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config) # release memory transformer = None text_encoder_1 = None text_encoder_2 = None vae = None + for k in kwargs.keys(): + kwargs[k] = None devices.torch_gc() return pipe diff --git a/modules/model_omnigen.py b/modules/model_omnigen.py index 64c99ddd7..a08ad4ed5 100644 --- a/modules/model_omnigen.py +++ b/modules/model_omnigen.py @@ -17,7 +17,8 @@ def load_omnigen(checkpoint_info, diffusers_load_config={}): # pylint: disable=u pipe.separate_cfg_infer = True pipe.use_kv_cache = False pipe.model.to(device=devices.device, dtype=devices.dtype) - pipe.model.eval() + if shared.opts.diffusers_eval: + pipe.model.eval() pipe.vae.to(devices.device, dtype=devices.dtype) devices.torch_gc() diff --git a/modules/model_quant.py b/modules/model_quant.py index 0e7bdd4b3..10e5eb99e 100644 --- a/modules/model_quant.py +++ b/modules/model_quant.py @@ -5,6 +5,7 @@ bnb = None quanto = None +ao = None def create_bnb_config(kwargs = None, allow_bnb: bool = True): @@ -12,6 +13,8 @@ def create_bnb_config(kwargs = None, allow_bnb: bool = True): if len(shared.opts.bnb_quantization) > 0 and allow_bnb: if 'Model' in shared.opts.bnb_quantization: load_bnb() + if bnb is None: + return kwargs bnb_config = diffusers.BitsAndBytesConfig( load_in_8bit=shared.opts.bnb_quantization_type in ['fp8'], load_in_4bit=shared.opts.bnb_quantization_type in ['nf4', 'fp4'], @@ -28,11 +31,49 @@ def create_bnb_config(kwargs = None, allow_bnb: bool = True): return kwargs +def create_ao_config(kwargs = None, allow_ao: bool = True): + from modules import shared + if len(shared.opts.torchao_quantization) > 0 and shared.opts.torchao_quantization_mode == 'pre' and allow_ao: + if 'Model' in shared.opts.torchao_quantization: + load_torchao() + if ao is None: + return kwargs + diffusers.utils.import_utils.is_torchao_available = lambda: True + ao_config = diffusers.TorchAoConfig(shared.opts.torchao_quantization_type) + shared.log.debug(f'Quantization: module=all type=torchao dtype={shared.opts.torchao_quantization_type}') + if kwargs is None: + return ao_config + else: + kwargs['quantization_config'] = ao_config + return kwargs + return kwargs + + +def load_torchao(msg='', silent=False): + global ao # pylint: disable=global-statement + if ao is not None: + return ao + install('torchao==0.7.0', quiet=True) + try: + import torchao + ao = torchao + fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access + log.debug(f'Quantization: type=torchao version={ao.__version__} fn={fn}') # pylint: disable=protected-access + return ao + except Exception as e: + if len(msg) > 0: + log.error(f"{msg} failed to import torchao: {e}") + ao = None + if not silent: + raise + return None + + def load_bnb(msg='', silent=False): global bnb # pylint: disable=global-statement if bnb is not None: return bnb - install('bitsandbytes', quiet=True) + install('bitsandbytes==0.45.0', quiet=True) try: import bitsandbytes bnb = bitsandbytes @@ -51,15 +92,18 @@ def load_bnb(msg='', silent=False): def 
load_quanto(msg='', silent=False): + from modules import shared global quanto # pylint: disable=global-statement if quanto is not None: return quanto - install('optimum-quanto', quiet=True) + install('optimum-quanto==0.2.6', quiet=True) try: from optimum import quanto as optimum_quanto # pylint: disable=no-name-in-module quanto = optimum_quanto fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access log.debug(f'Quantization: type=quanto version={quanto.__version__} fn={fn}') # pylint: disable=protected-access + if shared.opts.diffusers_offload_mode != 'none': + shared.log.error(f'Quantization: type=quanto offload={shared.opts.diffusers_offload_mode} not supported') return quanto except Exception as e: if len(msg) > 0: diff --git a/modules/model_sana.py b/modules/model_sana.py new file mode 100644 index 000000000..b9f56c7c6 --- /dev/null +++ b/modules/model_sana.py @@ -0,0 +1,79 @@ +import os +import time +import torch +import diffusers +import transformers +from modules import shared, sd_models, devices, modelloader, model_quant + + +def load_quants(kwargs, repo_id, cache_dir): + if len(shared.opts.bnb_quantization) > 0: + quant_args = {} + quant_args = model_quant.create_bnb_config(quant_args) + quant_args = model_quant.create_ao_config(quant_args) + load_args = kwargs.copy() + if not quant_args: + return kwargs + model_quant.load_bnb(f'Load model: type=SD3 quant={quant_args} args={load_args}') + if 'Model' in shared.opts.bnb_quantization and 'transformer' not in kwargs: + kwargs['transformer'] = diffusers.models.SanaTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=cache_dir, **load_args, **quant_args) + shared.log.debug(f'Quantization: module=transformer type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}') + if 'Text Encoder' in shared.opts.bnb_quantization and 'text_encoder_3' not in kwargs: + kwargs['text_encoder_3'] = transformers.AutoModelForCausalLM.from_pretrained(repo_id, subfolder="text_encoder", cache_dir=cache_dir, **load_args, **quant_args) + shared.log.debug(f'Quantization: module=t5 type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}') + return kwargs + + +def load_sana(checkpoint_info, kwargs={}): + modelloader.hf_login() + + fn = checkpoint_info if isinstance(checkpoint_info, str) else checkpoint_info.path + repo_id = sd_models.path_to_repo(fn) + kwargs.pop('load_connected_pipeline', None) + kwargs.pop('safety_checker', None) + kwargs.pop('requires_safety_checker', None) + kwargs.pop('torch_dtype', None) + + if not repo_id.endswith('_diffusers'): + repo_id = f'{repo_id}_diffusers' + if devices.dtype == torch.bfloat16 and 'BF16' not in repo_id: + repo_id = repo_id.replace('_diffusers', '_BF16_diffusers') + + if 'Sana_1600M' in repo_id: + if devices.dtype == torch.bfloat16 or 'BF16' in repo_id: + if 'BF16' not in repo_id: + repo_id = repo_id.replace('_diffusers', '_BF16_diffusers') + kwargs['variant'] = 'bf16' + kwargs['torch_dtype'] = devices.dtype + else: + kwargs['variant'] = 'fp16' + if 'Sana_600M' in repo_id: + kwargs['variant'] = 'fp16' + + if (fn is None) or (not os.path.exists(fn) or os.path.isdir(fn)): + kwargs = load_quants(kwargs, repo_id, cache_dir=shared.opts.diffusers_dir) + # kwargs = model_quant.create_bnb_config(kwargs) + # kwargs = model_quant.create_ao_config(kwargs) + shared.log.debug(f'Load model: type=Sana repo="{repo_id}" args={list(kwargs)}') + t0 = time.time() + pipe = 
diffusers.SanaPipeline.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **kwargs) + if devices.dtype == torch.bfloat16 or devices.dtype == torch.float32: + if 'transformer' not in kwargs: + pipe.transformer = pipe.transformer.to(dtype=devices.dtype) + if 'text_encoder' not in kwargs: + pipe.text_encoder = pipe.text_encoder.to(dtype=devices.dtype) + pipe.vae = pipe.vae.to(dtype=devices.dtype) + if devices.dtype == torch.float16: + if 'transformer' not in kwargs: + pipe.transformer = pipe.transformer.to(dtype=devices.dtype) + if 'text_encoder' not in kwargs: + pipe.text_encoder = pipe.text_encoder.to(dtype=torch.float32) # gemma2 does not support fp16 + pipe.vae = pipe.vae.to(dtype=torch.float32) # dc-ae often overflows in fp16 + if shared.opts.diffusers_eval: + pipe.text_encoder.eval() + pipe.transformer.eval() + t1 = time.time() + shared.log.debug(f'Load model: type=Sana target={devices.dtype} te={pipe.text_encoder.dtype} transformer={pipe.transformer.dtype} vae={pipe.vae.dtype} time={t1-t0:.2f}') + + devices.torch_gc() + return pipe diff --git a/modules/model_sd3.py b/modules/model_sd3.py index b9d579085..1834c573e 100644 --- a/modules/model_sd3.py +++ b/modules/model_sd3.py @@ -1,7 +1,7 @@ import os import diffusers import transformers -from modules import shared, devices, sd_models, sd_unet, model_te, model_quant, model_tools +from modules import shared, devices, sd_models, sd_unet, model_quant, model_tools def load_overrides(kwargs, cache_dir): @@ -13,7 +13,9 @@ def load_overrides(kwargs, cache_dir): sd_unet.loaded_unet = shared.opts.sd_unet shared.log.debug(f'Load model: type=SD3 unet="{shared.opts.sd_unet}" fmt=safetensors') elif fn.endswith('.gguf'): - kwargs = load_gguf(kwargs, fn) + from modules import ggml + # kwargs = load_gguf(kwargs, fn) + kwargs['transformer'] = ggml.load_gguf(fn, cls=diffusers.SD3Transformer2DModel, compute_dtype=devices.dtype) sd_unet.loaded_unet = shared.opts.sd_unet shared.log.debug(f'Load model: type=SD3 unet="{shared.opts.sd_unet}" fmt=gguf') except Exception as e: @@ -51,19 +53,17 @@ def load_overrides(kwargs, cache_dir): def load_quants(kwargs, repo_id, cache_dir): if len(shared.opts.bnb_quantization) > 0: - model_quant.load_bnb('Load model: type=SD3') - bnb_config = diffusers.BitsAndBytesConfig( - load_in_8bit=shared.opts.bnb_quantization_type in ['fp8'], - load_in_4bit=shared.opts.bnb_quantization_type in ['nf4', 'fp4'], - bnb_4bit_quant_storage=shared.opts.bnb_quantization_storage, - bnb_4bit_quant_type=shared.opts.bnb_quantization_type, - bnb_4bit_compute_dtype=devices.dtype - ) + quant_args = {} + quant_args = model_quant.create_bnb_config(quant_args) + quant_args = model_quant.create_ao_config(quant_args) + if not quant_args: + return kwargs + model_quant.load_bnb(f'Load model: type=SD3 quant={quant_args}') if 'Model' in shared.opts.bnb_quantization and 'transformer' not in kwargs: - kwargs['transformer'] = diffusers.SD3Transformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=cache_dir, quantization_config=bnb_config, torch_dtype=devices.dtype) + kwargs['transformer'] = diffusers.SD3Transformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args) shared.log.debug(f'Quantization: module=transformer type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}') if 'Text Encoder' in shared.opts.bnb_quantization and 'text_encoder_3' not in kwargs: - kwargs['text_encoder_3'] = 
transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_3", variant='fp16', cache_dir=cache_dir, quantization_config=bnb_config, torch_dtype=devices.dtype) + kwargs['text_encoder_3'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_3", variant='fp16', cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args) shared.log.debug(f'Quantization: module=t5 type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}') return kwargs @@ -92,8 +92,9 @@ def load_missing(kwargs, fn, cache_dir): return kwargs +""" def load_gguf(kwargs, fn): - model_te.install_gguf() + ggml.install_gguf() from accelerate import init_empty_weights from diffusers.loaders.single_file_utils import convert_sd3_transformer_checkpoint_to_diffusers from modules import ggml, sd_hijack_accelerate @@ -110,10 +111,12 @@ def load_gguf(kwargs, fn): continue applied += 1 sd_hijack_accelerate.hijack_set_module_tensor_simple(transformer, tensor_name=param_name, value=param, device=0) + transformer.gguf = 'gguf' state_dict[param_name] = None shared.log.debug(f'Load model: type=Unet/Transformer applied={applied} skipped={skipped} stats={stats} compute={devices.dtype}') kwargs['transformer'] = transformer return kwargs +""" def load_sd3(checkpoint_info, cache_dir=None, config=None): @@ -127,7 +130,7 @@ def load_sd3(checkpoint_info, cache_dir=None, config=None): kwargs = {} kwargs = load_overrides(kwargs, cache_dir) - if fn is None or not os.path.exists(fn): + if (fn is None) or (not os.path.exists(fn) or os.path.isdir(fn)): kwargs = load_quants(kwargs, repo_id, cache_dir) loader = diffusers.StableDiffusion3Pipeline.from_pretrained @@ -141,7 +144,9 @@ def load_sd3(checkpoint_info, cache_dir=None, config=None): # kwargs = load_missing(kwargs, fn, cache_dir) repo_id = fn elif fn.endswith('.gguf'): - kwargs = load_gguf(kwargs, fn) + from modules import ggml + kwargs['transformer'] = ggml.load_gguf(fn, cls=diffusers.SD3Transformer2DModel, compute_dtype=devices.dtype) + # kwargs = load_gguf(kwargs, fn) kwargs = load_missing(kwargs, fn, cache_dir) kwargs['variant'] = 'fp16' else: @@ -150,6 +155,7 @@ def load_sd3(checkpoint_info, cache_dir=None, config=None): shared.log.debug(f'Load model: type=SD3 kwargs={list(kwargs)} repo="{repo_id}"') kwargs = model_quant.create_bnb_config(kwargs) + kwargs = model_quant.create_ao_config(kwargs) pipe = loader( repo_id, torch_dtype=devices.dtype, diff --git a/modules/model_stablecascade.py b/modules/model_stablecascade.py index 6c23ea00a..2a7739e55 100644 --- a/modules/model_stablecascade.py +++ b/modules/model_stablecascade.py @@ -187,8 +187,7 @@ def __call__( callback_on_step_end=None, callback_on_step_end_tensor_inputs=["latents"], ): - if shared.opts.diffusers_offload_mode == "balanced": - shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) + shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) # 0. 
Define commonly used variables self.guidance_scale = guidance_scale self.do_classifier_free_guidance = self.guidance_scale > 1 @@ -330,14 +329,11 @@ def __call__( elif output_type == "pil": images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesnt work images = self.numpy_to_pil(images) + shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) else: images = latents - # Offload all models - if shared.opts.diffusers_offload_mode == "balanced": - shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) - else: - self.maybe_free_model_hooks() + self.maybe_free_model_hooks() if not return_dict: return images diff --git a/modules/model_te.py b/modules/model_te.py index 606cdb86d..16bb6d222 100644 --- a/modules/model_te.py +++ b/modules/model_te.py @@ -12,20 +12,6 @@ loaded_te = None -def install_gguf(): - # pip install git+https://github.com/junejae/transformers@feature/t5-gguf - install('gguf', quiet=True) - # https://github.com/ggerganov/llama.cpp/issues/9566 - import gguf - scripts_dir = os.path.join(os.path.dirname(gguf.__file__), '..', 'scripts') - if os.path.exists(scripts_dir): - os.rename(scripts_dir, scripts_dir + '_gguf') - # monkey patch transformers so they detect gguf pacakge correctly - import importlib - transformers.utils.import_utils._is_gguf_available = True # pylint: disable=protected-access - transformers.utils.import_utils._gguf_version = importlib.metadata.version('gguf') # pylint: disable=protected-access - - def load_t5(name=None, cache_dir=None): global loaded_te # pylint: disable=global-statement if name is None: @@ -34,8 +20,9 @@ def load_t5(name=None, cache_dir=None): modelloader.hf_login() repo_id = 'stabilityai/stable-diffusion-3-medium-diffusers' fn = te_dict.get(name) if name in te_dict else None - if fn is not None and 'gguf' in name.lower(): - install_gguf() + if fn is not None and name.lower().endswith('gguf'): + from modules import ggml + ggml.install_gguf() with open(os.path.join('configs', 'flux', 'text_encoder_2', 'config.json'), encoding='utf8') as f: t5_config = transformers.T5Config(**json.load(f)) t5 = transformers.T5EncoderModel.from_pretrained(None, gguf_file=fn, config=t5_config, device_map="auto", cache_dir=cache_dir, torch_dtype=devices.dtype) @@ -52,7 +39,8 @@ def load_t5(name=None, cache_dir=None): if torch.is_floating_point(param) and not is_param_float8_e4m3fn: param = param.to(devices.dtype) set_module_tensor_to_device(t5, param_name, device=0, value=param) - t5.eval() + if shared.opts.diffusers_eval: + t5.eval() if t5.dtype != devices.dtype: try: t5 = t5.to(dtype=devices.dtype) diff --git a/modules/model_tools.py b/modules/model_tools.py index 07cd61b6e..fdeda5c2a 100644 --- a/modules/model_tools.py +++ b/modules/model_tools.py @@ -69,11 +69,13 @@ def load_modules(repo_id: str, params: dict): subfolder = 'text_encoder_2' if cls == transformers.T5EncoderModel: # t5-xxl subfolder = 'text_encoder_3' - kwargs['quantization_config'] = model_quant.create_bnb_config() + kwargs = model_quant.create_bnb_config(kwargs) + kwargs = model_quant.create_ao_config(kwargs) kwargs['variant'] = 'fp16' if cls == diffusers.SD3Transformer2DModel: subfolder = 'transformer' - kwargs['quantization_config'] = model_quant.create_bnb_config() + kwargs = model_quant.create_bnb_config(kwargs) + kwargs = model_quant.create_ao_config(kwargs) if subfolder is None: continue shared.log.debug(f'Load: module={name} class={cls.__name__} repo={repo_id} location={subfolder}') diff --git a/modules/modeldata.py 
b/modules/modeldata.py index 604ff4623..4b7ec1776 100644 --- a/modules/modeldata.py +++ b/modules/modeldata.py @@ -3,6 +3,45 @@ from modules import shared, errors +def get_model_type(pipe): + name = pipe.__class__.__name__ + if not shared.native: + model_type = 'ldm' + elif "StableDiffusion3" in name: + model_type = 'sd3' + elif "StableDiffusionXL" in name: + model_type = 'sdxl' + elif "StableDiffusion" in name: + model_type = 'sd' + elif "LatentConsistencyModel" in name: + model_type = 'sd' # lcm is compatible with sd + elif "InstaFlowPipeline" in name: + model_type = 'sd' # instaflow is compatible with sd + elif "AnimateDiffPipeline" in name: + model_type = 'sd' # animatediff is compatible with sd + elif "Kandinsky" in name: + model_type = 'kandinsky' + elif "HunyuanDiT" in name: + model_type = 'hunyuandit' + elif "Cascade" in name: + model_type = 'sc' + elif "AuraFlow" in name: + model_type = 'auraflow' + elif "Flux" in name: + model_type = 'f1' + elif "Lumina" in name: + model_type = 'lumina' + elif "OmniGen" in name: + model_type = 'omnigen' + elif "CogVideo" in name: + model_type = 'cogvideox' + elif "Sana" in name: + model_type = 'sana' + else: + model_type = name + return model_type + + class ModelData: def __init__(self): self.sd_model = None @@ -82,36 +121,7 @@ def sd_model_type(self): if modules.sd_models.model_data.sd_model is None: model_type = 'none' return model_type - if not shared.native: - model_type = 'ldm' - elif "StableDiffusion3" in self.sd_model.__class__.__name__: - model_type = 'sd3' - elif "StableDiffusionXL" in self.sd_model.__class__.__name__: - model_type = 'sdxl' - elif "StableDiffusion" in self.sd_model.__class__.__name__: - model_type = 'sd' - elif "LatentConsistencyModel" in self.sd_model.__class__.__name__: - model_type = 'sd' # lcm is compatible with sd - elif "InstaFlowPipeline" in self.sd_model.__class__.__name__: - model_type = 'sd' # instaflow is compatible with sd - elif "AnimateDiffPipeline" in self.sd_model.__class__.__name__: - model_type = 'sd' # animatediff is compatible with sd - elif "Kandinsky" in self.sd_model.__class__.__name__: - model_type = 'kandinsky' - elif "HunyuanDiT" in self.sd_model.__class__.__name__: - model_type = 'hunyuandit' - elif "Cascade" in self.sd_model.__class__.__name__: - model_type = 'sc' - elif "AuraFlow" in self.sd_model.__class__.__name__: - model_type = 'auraflow' - elif "Flux" in self.sd_model.__class__.__name__: - model_type = 'f1' - elif "OmniGen" in self.sd_model.__class__.__name__: - model_type = 'omnigen' - elif "CogVideo" in self.sd_model.__class__.__name__: - model_type = 'cogvideox' - else: - model_type = self.sd_model.__class__.__name__ + model_type = get_model_type(self.sd_model) except Exception: model_type = 'unknown' return model_type @@ -123,18 +133,7 @@ def sd_refiner_type(self): if modules.sd_models.model_data.sd_refiner is None: model_type = 'none' return model_type - if not shared.native: - model_type = 'ldm' - elif "StableDiffusion3" in self.sd_refiner.__class__.__name__: - model_type = 'sd3' - elif "StableDiffusionXL" in self.sd_refiner.__class__.__name__: - model_type = 'sdxl' - elif "StableDiffusion" in self.sd_refiner.__class__.__name__: - model_type = 'sd' - elif "Kandinsky" in self.sd_refiner.__class__.__name__: - model_type = 'kandinsky' - else: - model_type = self.sd_refiner.__class__.__name__ + model_type = get_model_type(self.sd_refiner) except Exception: model_type = 'unknown' return model_type diff --git a/modules/modelloader.py b/modules/modelloader.py index ce36a739b..b022b4fc6 
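A quick illustration of the new `get_model_type()` helper: it only inspects the pipeline class name, so stand-in objects are enough to show the mapping (output assumes a running SD.Next session where `shared.native` is set):

```python
from modules.modeldata import get_model_type  # helper added above

for cls_name in ('StableDiffusionXLPipeline', 'FluxPipeline', 'SanaPipeline', 'HunyuanDiTPipeline'):
    obj = type(cls_name, (), {})()  # dummy object whose class name matches a real pipeline
    print(cls_name, '->', get_model_type(obj))  # sdxl, f1, sana, hunyuandit
```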
100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -267,7 +267,7 @@ def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config def load_diffusers_models(clear=True): excluded_models = [] - t0 = time.time() + # t0 = time.time() place = shared.opts.diffusers_dir if place is None or len(place) == 0 or not os.path.isdir(place): place = os.path.join(models_path, 'Diffusers') @@ -316,7 +316,7 @@ def load_diffusers_models(clear=True): debug(f'Error analyzing diffusers model: "{folder}" {e}') except Exception as e: shared.log.error(f"Error listing diffusers: {place} {e}") - shared.log.debug(f'Scanning diffusers cache: folder="{place}" items={len(list(diffuser_repos))} time={time.time()-t0:.2f}') + # shared.log.debug(f'Scanning diffusers cache: folder="{place}" items={len(list(diffuser_repos))} time={time.time()-t0:.2f}') return diffuser_repos @@ -326,6 +326,9 @@ def find_diffuser(name: str, full=False): return [repo[0]['name']] hf_api = hf.HfApi() models = list(hf_api.list_models(model_name=name, library=['diffusers'], full=True, limit=20, sort="downloads", direction=-1)) + if len(models) == 0: + models = list(hf_api.list_models(model_name=name, full=True, limit=20, sort="downloads", direction=-1)) # widen search + models = [m for m in models if m.id.startswith(name)] # filter exact shared.log.debug(f'Searching diffusers models: {name} {len(models) > 0}') if len(models) > 0: if not full: diff --git a/modules/omnigen/utils.py b/modules/omnigen/utils.py index 5483d6eab..0304e1732 100644 --- a/modules/omnigen/utils.py +++ b/modules/omnigen/utils.py @@ -25,7 +25,6 @@ def update_ema(ema_model, model, decay=0.9999): """ ema_params = dict(ema_model.named_parameters()) for name, param in model.named_parameters(): - # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) diff --git a/modules/onnx_impl/pipelines/__init__.py b/modules/onnx_impl/pipelines/__init__.py index ca1ddd2f7..a11b07fc7 100644 --- a/modules/onnx_impl/pipelines/__init__.py +++ b/modules/onnx_impl/pipelines/__init__.py @@ -241,7 +241,7 @@ def run_olive(self, submodels: List[str], in_dir: os.PathLike, out_dir: os.PathL for i in range(len(flow)): flow[i] = flow[i].replace("AutoExecutionProvider", shared.opts.onnx_execution_provider) olive_config["input_model"]["config"]["model_path"] = os.path.abspath(os.path.join(in_dir, submodel, "model.onnx")) - olive_config["systems"]["local_system"]["config"]["accelerators"][0]["device"] = "cpu" if shared.opts.onnx_execution_provider == ExecutionProvider.CPU else "gpu" # TODO: npu + olive_config["systems"]["local_system"]["config"]["accelerators"][0]["device"] = "cpu" if shared.opts.onnx_execution_provider == ExecutionProvider.CPU else "gpu" olive_config["systems"]["local_system"]["config"]["accelerators"][0]["execution_providers"] = [shared.opts.onnx_execution_provider] for pass_key in olive_config["passes"]: diff --git a/modules/pag/__init__.py b/modules/pag/__init__.py index 29cdee8ca..a72f7825d 100644 --- a/modules/pag/__init__.py +++ b/modules/pag/__init__.py @@ -12,33 +12,41 @@ def apply(p: processing.StableDiffusionProcessing): # pylint: disable=arguments- global orig_pipeline # pylint: disable=global-statement if not shared.native: return None - c = shared.sd_model.__class__ if shared.sd_loaded else None - if c == StableDiffusionPAGPipeline or c == StableDiffusionXLPAGPipeline: - unapply() + cls = shared.sd_model.__class__ if shared.sd_loaded 
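The widened hub search above first queries with the `diffusers` library filter and, only when that returns nothing, retries without the filter and keeps exact-prefix matches. A rough standalone equivalent mirroring the call used in `find_diffuser` (exact parameter names depend on the installed `huggingface_hub` version):

```python
from huggingface_hub import HfApi

def find_repo(name: str, limit: int = 20):
    api = HfApi()
    models = list(api.list_models(model_name=name, library=['diffusers'], full=True, limit=limit, sort='downloads', direction=-1))
    if len(models) == 0:  # widen the search, then keep exact-prefix matches only
        models = list(api.list_models(model_name=name, full=True, limit=limit, sort='downloads', direction=-1))
        models = [m for m in models if m.id.startswith(name)]
    return [m.id for m in models]
```

The prefix filter keeps unrelated finetunes from being returned when the widened query matches loosely.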
else None + if cls == StableDiffusionPAGPipeline or cls == StableDiffusionXLPAGPipeline: + cls = unapply() if p.pag_scale == 0: return - if sd_models.get_diffusers_task(shared.sd_model) != sd_models.DiffusersTaskType.TEXT_2_IMAGE: - shared.log.warning(f'PAG: pipeline={c} not implemented') - return None - if detect.is_sd15(c): + if 'PAG' in cls.__name__: + pass + elif detect.is_sd15(cls): + if sd_models.get_diffusers_task(shared.sd_model) != sd_models.DiffusersTaskType.TEXT_2_IMAGE: + shared.log.warning(f'PAG: pipeline={cls.__name__} not implemented') + return None orig_pipeline = shared.sd_model shared.sd_model = sd_models.switch_pipe(StableDiffusionPAGPipeline, shared.sd_model) - elif detect.is_sdxl(c): + elif detect.is_sdxl(cls): + if sd_models.get_diffusers_task(shared.sd_model) != sd_models.DiffusersTaskType.TEXT_2_IMAGE: + shared.log.warning(f'PAG: pipeline={cls.__name__} not implemented') + return None orig_pipeline = shared.sd_model shared.sd_model = sd_models.switch_pipe(StableDiffusionXLPAGPipeline, shared.sd_model) + elif detect.is_f1(cls): + p.task_args['true_cfg_scale'] = p.pag_scale else: - shared.log.warning(f'PAG: pipeline={c} required={StableDiffusionPipeline.__name__}') + shared.log.warning(f'PAG: pipeline={cls.__name__} required={StableDiffusionPipeline.__name__}') return None p.task_args['pag_scale'] = p.pag_scale p.task_args['pag_adaptive_scaling'] = p.pag_adaptive + p.task_args['pag_adaptive_scale'] = p.pag_adaptive pag_applied_layers = shared.opts.pag_apply_layers pag_applied_layers_index = pag_applied_layers.split() if len(pag_applied_layers) > 0 else [] pag_applied_layers_index = [p.strip() for p in pag_applied_layers_index] p.task_args['pag_applied_layers_index'] = pag_applied_layers_index if len(pag_applied_layers_index) > 0 else ['m0'] # Available layers: d[0-5], m[0], u[0-8] p.extra_generation_params["PAG scale"] = p.pag_scale p.extra_generation_params["PAG adaptive"] = p.pag_adaptive - shared.log.debug(f'{c}: args={p.task_args}') + # shared.log.debug(f'{c}: args={p.task_args}') def unapply(): @@ -46,3 +54,4 @@ def unapply(): if orig_pipeline is not None: shared.sd_model = orig_pipeline orig_pipeline = None + return shared.sd_model.__class__ diff --git a/modules/pag/pipe_sd.py b/modules/pag/pipe_sd.py index 16f9b5319..11f4fb0cf 100644 --- a/modules/pag/pipe_sd.py +++ b/modules/pag/pipe_sd.py @@ -104,7 +104,6 @@ def __call__( value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states_org = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) @@ -219,7 +218,6 @@ def __call__( value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states_org = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) diff --git a/modules/pag/pipe_sdxl.py b/modules/pag/pipe_sdxl.py index 82ae06c07..3a47af3e5 100644 --- a/modules/pag/pipe_sdxl.py +++ b/modules/pag/pipe_sdxl.py @@ -124,7 +124,6 @@ def __call__( value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states_org = F.scaled_dot_product_attention( query, key, value, 
attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
@@ -239,7 +238,6 @@ def __call__(
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states_org = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
diff --git a/modules/postprocess/yolo.py b/modules/postprocess/yolo.py
index 0e733d419..8fc203280 100644
--- a/modules/postprocess/yolo.py
+++ b/modules/postprocess/yolo.py
@@ -72,7 +72,7 @@ def predict(
     imgsz: int = 640,
     half: bool = True,
     device = devices.device,
-    augment: bool = True,
+    augment: bool = shared.opts.detailer_augment,
     agnostic: bool = False,
     retina: bool = False,
     mask: bool = True,
@@ -300,7 +300,8 @@ def restore(self, np_image, p: processing.StableDiffusionProcessing = None):
                 # combined.save('/tmp/item.png')
             p.image_mask = Image.fromarray(p.image_mask)
-        shared.log.debug(f'Detailer processed: models={models_used}')
+        if len(models_used) > 0:
+            shared.log.debug(f'Detailer processed: models={models_used}')
         return np_image
 
     def ui(self, tab: str):
diff --git a/modules/processing.py b/modules/processing.py
index 0d557e64e..c85938268 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -118,6 +118,9 @@ def js(self):
     def infotext(self, p: StableDiffusionProcessing, index):
         return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
 
+    def __str__(self):
+        return f'{self.__class__.__name__}: {self.__dict__}'
+
 
 def process_images(p: StableDiffusionProcessing) -> Processed:
     timer.process.reset()
@@ -176,6 +179,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
     timer.process.record('pre')
     if shared.cmd_opts.profile:
+        timer.startup.profile = True
+        timer.process.profile = True
         with context_hypertile_vae(p), context_hypertile_unet(p):
             import torch.profiler # pylint: disable=redefined-outer-name
             activities=[torch.profiler.ProfilerActivity.CPU]
@@ -286,7 +291,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     t0 = time.time()
     if not hasattr(p, 'skip_init'):
         p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
-    extra_network_data = None
     debug(f'Processing inner: args={vars(p)}')
     for n in range(p.n_iter):
         pag.apply(p)
@@ -311,9 +315,9 @@
             p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
         if len(p.prompts) == 0:
             break
-        p.prompts, extra_network_data = extra_networks.parse_prompts(p.prompts)
-        if not p.disable_extra_networks:
-            extra_networks.activate(p, extra_network_data)
+        p.prompts, p.network_data = extra_networks.parse_prompts(p.prompts)
+        if not shared.native:
+            extra_networks.activate(p, p.network_data)
         if p.scripts is not None and isinstance(p.scripts, scripts.ScriptRunner):
             p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
@@ -323,7 +327,7 @@
             processed = p.scripts.process_images(p)
             if processed is not None:
                 samples = processed.images
-                infotexts = processed.infotexts
+                infotexts += processed.infotexts
         if samples is None:
             if not shared.native:
                 from modules.processing_original import process_original
@@ -353,7 +357,7 @@ def process_images_inner(p: StableDiffusionProcessing)
-> Processed: for i, sample in enumerate(samples): debug(f'Processing result: index={i+1}/{len(samples)} iteration={n+1}/{p.n_iter}') p.batch_index = i - if type(sample) == Image.Image: + if isinstance(sample, Image.Image) or (isinstance(sample, list) and isinstance(sample[0], Image.Image)): image = sample sample = np.array(sample) else: @@ -393,16 +397,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if shared.opts.mask_apply_overlay: image = apply_overlay(image, p.paste_to, i, p.overlay_images) - if len(infotexts) > i: - info = infotexts[i] + info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i, all_negative_prompts=p.negative_prompts) + infotexts.append(info) + if isinstance(image, list): + for img in image: + img.info["parameters"] = info + output_images = image else: - info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i, all_negative_prompts=p.negative_prompts) - infotexts.append(info) - image.info["parameters"] = info - output_images.append(image) + image.info["parameters"] = info + output_images.append(image) if shared.opts.samples_save and not p.do_not_save_samples and p.outpath_samples is not None: info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i) - images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=info, p=p) # main save image + if isinstance(image, list): + for img in image: + images.save_image(img, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=info, p=p) # main save image + else: + images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=info, p=p) # main save image if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([shared.opts.save_mask, shared.opts.save_mask_composite, shared.opts.return_mask, shared.opts.return_mask_composite]): image_mask = p.mask_for_overlay.convert('RGB') image1 = image.convert('RGBA').convert('RGBa') @@ -420,6 +430,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: timer.process.record('post') del samples + + if not shared.native: + extra_networks.deactivate(p, p.network_data) + devices.torch_gc() if hasattr(shared.sd_model, 'restore_pipeline') and shared.sd_model.restore_pipeline is not None: @@ -446,10 +460,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if shared.native: from modules import ipadapter - ipadapter.unapply(shared.sd_model) - - if not p.disable_extra_networks: - extra_networks.deactivate(p, extra_network_data) + ipadapter.unapply(shared.sd_model, unload=getattr(p, 'ip_adapter_unload', False)) if shared.opts.include_mask: if shared.opts.mask_apply_overlay and p.overlay_images is not None and len(p.overlay_images): @@ -475,5 +486,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None and isinstance(p.scripts, scripts.ScriptRunner) and not (shared.state.interrupted or shared.state.skipped): p.scripts.postprocess(p, processed) timer.process.record('post') - shared.log.info(f'Processed: images={len(output_images)} its={(p.steps * len(output_images)) / (t1 - t0):.2f} time={t1-t0:.2f} timers={timer.process.dct(min_time=0.02)} memory={memstats.memory_stats()}') + if not p.disable_extra_networks: + shared.log.info(f'Processed: images={len(output_images)} its={(p.steps * len(output_images)) / (t1 - t0):.2f} time={t1-t0:.2f} timers={timer.process.dct()} memory={memstats.memory_stats()}') + + devices.torch_gc(force=True, 
reason='final') return processed diff --git a/modules/processing_args.py b/modules/processing_args.py index ff766ec04..5ca6cfcd3 100644 --- a/modules/processing_args.py +++ b/modules/processing_args.py @@ -6,13 +6,14 @@ import inspect import torch import numpy as np -from modules import shared, errors, sd_models, processing, processing_vae, processing_helpers, sd_hijack_hypertile, prompt_parser_diffusers, timer +from modules import shared, errors, sd_models, processing, processing_vae, processing_helpers, sd_hijack_hypertile, prompt_parser_diffusers, timer, extra_networks from modules.processing_callbacks import diffusers_callback_legacy, diffusers_callback, set_callbacks_p from modules.processing_helpers import resize_hires, fix_prompts, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, get_generator, set_latents, apply_circular # pylint: disable=unused-import from modules.api import helpers -debug = shared.log.trace if os.environ.get('SD_DIFFUSERS_DEBUG', None) is not None else lambda *args, **kwargs: None +debug_enabled = os.environ.get('SD_DIFFUSERS_DEBUG', None) +debug_log = shared.log.trace if os.environ.get('SD_DIFFUSERS_DEBUG', None) is not None else lambda *args, **kwargs: None def task_specific_kwargs(p, model): @@ -22,7 +23,7 @@ def task_specific_kwargs(p, model): if isinstance(p.init_images[0], str): p.init_images = [helpers.decode_base64_to_image(i, quiet=True) for i in p.init_images] p.init_images = [i.convert('RGB') if i.mode != 'RGB' else i for i in p.init_images] - if sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.TEXT_2_IMAGE or len(getattr(p, 'init_images', [])) == 0 and not is_img2img_model: + if (sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.TEXT_2_IMAGE or len(getattr(p, 'init_images', [])) == 0) and not is_img2img_model: p.ops.append('txt2img') if hasattr(p, 'width') and hasattr(p, 'height'): task_args = { @@ -93,12 +94,14 @@ def task_specific_kwargs(p, model): 'target_subject_category': getattr(p, 'prompt', '').split()[-1], 'output_type': 'pil', } - debug(f'Diffusers task specific args: {task_args}') + if debug_enabled: + debug_log(f'Diffusers task specific args: {task_args}') return task_args def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2: typing.Optional[list]=None, negative_prompts_2: typing.Optional[list]=None, desc:str='', **kwargs): t0 = time.time() + shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) apply_circular(p.tiling, model) if hasattr(model, "set_progress_bar_config"): model.set_progress_bar_config(bar_format='Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining} ' + '\x1b[38;5;71m' + desc, ncols=80, colour='#327fba') @@ -108,7 +111,8 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2 signature = inspect.signature(type(model).__call__, follow_wrapped=True) possible = list(signature.parameters) - debug(f'Diffusers pipeline possible: {possible}') + if debug_enabled: + debug_log(f'Diffusers pipeline possible: {possible}') prompts, negative_prompts, prompts_2, negative_prompts_2 = fix_prompts(prompts, negative_prompts, prompts_2, negative_prompts_2) steps = kwargs.get("num_inference_steps", None) or len(getattr(p, 'timesteps', ['1'])) clip_skip = kwargs.pop("clip_skip", 1) @@ -130,12 +134,13 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2 else: prompt_parser_diffusers.embedder = None + extra_networks.activate(p, include=['text_encoder', 
'text_encoder_2', 'text_encoder_3']) if 'prompt' in possible: if 'OmniGen' in model.__class__.__name__: prompts = [p.replace('|image|', '<|image_1|>') for p in prompts] if hasattr(model, 'text_encoder') and hasattr(model, 'tokenizer') and 'prompt_embeds' in possible and prompt_parser_diffusers.embedder is not None: args['prompt_embeds'] = prompt_parser_diffusers.embedder('prompt_embeds') - if 'StableCascade' in model.__class__.__name__ and len(getattr(p, 'negative_pooleds', [])) > 0: + if 'StableCascade' in model.__class__.__name__ and prompt_parser_diffusers.embedder is not None: args['prompt_embeds_pooled'] = prompt_parser_diffusers.embedder('positive_pooleds').unsqueeze(0) elif 'XL' in model.__class__.__name__ and prompt_parser_diffusers.embedder is not None: args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds') @@ -159,6 +164,13 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2 args['negative_prompt'] = negative_prompts[0] else: args['negative_prompt'] = negative_prompts + if 'complex_human_instruction' in possible: + chi = any(len(p) < 300 for p in prompts) + p.extra_generation_params["CHI"] = chi + if not chi: + args['complex_human_instruction'] = None + if prompt_parser_diffusers.embedder is not None and not prompt_parser_diffusers.embedder.scheduled_prompt: # not scheduled so we dont need it anymore + prompt_parser_diffusers.embedder = None if 'clip_skip' in possible and parser == 'fixed': if clip_skip == 1: @@ -181,6 +193,21 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2 shared.log.error(f'Sampler timesteps: {e}') else: shared.log.warning(f'Sampler: sampler={model.scheduler.__class__.__name__} timesteps not supported') + if 'sigmas' in possible: + sigmas = re.split(',| ', shared.opts.schedulers_timesteps) + sigmas = [float(x)/1000.0 for x in sigmas if x.isdigit()] + if len(sigmas) > 0: + if hasattr(model.scheduler, 'set_timesteps') and "sigmas" in set(inspect.signature(model.scheduler.set_timesteps).parameters.keys()): + try: + args['sigmas'] = sigmas + p.steps = len(sigmas) + p.timesteps = sigmas + steps = p.steps + shared.log.debug(f'Sampler: steps={len(sigmas)} sigmas={sigmas}') + except Exception as e: + shared.log.error(f'Sampler sigmas: {e}') + else: + shared.log.warning(f'Sampler: sampler={model.scheduler.__class__.__name__} sigmas not supported') if hasattr(model, 'scheduler') and hasattr(model.scheduler, 'noise_sampler_seed') and hasattr(model.scheduler, 'noise_sampler'): model.scheduler.noise_sampler = None # noise needs to be reset instead of using cached values @@ -248,14 +275,16 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2 if arg in possible: args[arg] = task_kwargs[arg] task_args = getattr(p, 'task_args', {}) - debug(f'Diffusers task args: {task_args}') + if debug_enabled: + debug_log(f'Diffusers task args: {task_args}') for k, v in task_args.items(): if k in possible: args[k] = v else: - debug(f'Diffusers unknown task args: {k}={v}') + debug_log(f'Diffusers unknown task args: {k}={v}') cross_attention_args = getattr(p, 'cross_attention_kwargs', {}) - debug(f'Diffusers cross-attention args: {cross_attention_args}') + if debug_enabled: + debug_log(f'Diffusers cross-attention args: {cross_attention_args}') for k, v in cross_attention_args.items(): if args.get('cross_attention_kwargs', None) is None: args['cross_attention_kwargs'] = {} @@ -273,7 +302,7 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, 
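The new `sigmas` branch reuses the `schedulers_timesteps` option but rescales the values by 1/1000 before handing them to schedulers whose `set_timesteps` accepts explicit sigmas; the parsing itself is simple (example values below are arbitrary):

```python
import re

schedulers_timesteps = '999, 893, 756, 500, 250'  # example option string
sigmas = re.split(',| ', schedulers_timesteps)
sigmas = [float(x) / 1000.0 for x in sigmas if x.isdigit()]
print(sigmas)  # [0.999, 0.893, 0.756, 0.5, 0.25]; the list length also becomes the step count
```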
prompts_2 # handle implicit controlnet if 'control_image' in possible and 'control_image' not in args and 'image' in args: - debug('Diffusers: set control image') + debug_log('Diffusers: set control image') args['control_image'] = args['image'] sd_hijack_hypertile.hypertile_set(p, hr=len(getattr(p, 'init_images', [])) > 0) @@ -291,11 +320,14 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2 clean['negative_prompt'] = len(clean['negative_prompt']) clean.pop('generator', None) clean['parser'] = parser - for k, v in clean.items(): + for k, v in clean.copy().items(): if isinstance(v, torch.Tensor) or isinstance(v, np.ndarray): clean[k] = v.shape if isinstance(v, list) and len(v) > 0 and (isinstance(v[0], torch.Tensor) or isinstance(v[0], np.ndarray)): clean[k] = [x.shape for x in v] + if not debug_enabled and k.endswith('_embeds'): + del clean[k] + clean['prompt'] = 'embeds' shared.log.debug(f'Diffuser pipeline: {model.__class__.__name__} task={sd_models.get_diffusers_task(model)} batch={p.iteration + 1}/{p.n_iter}x{p.batch_size} set={clean}') if p.hdr_clamp or p.hdr_maximize or p.hdr_brightness != 0 or p.hdr_color != 0 or p.hdr_sharpen != 0: @@ -309,5 +341,6 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2 if shared.cmd_opts.profile: t1 = time.time() shared.log.debug(f'Profile: pipeline args: {t1-t0:.2f}') - debug(f'Diffusers pipeline args: {args}') + if debug_enabled: + debug_log(f'Diffusers pipeline args: {args}') return args diff --git a/modules/processing_callbacks.py b/modules/processing_callbacks.py index 52ea3e575..0b4c7dfe1 100644 --- a/modules/processing_callbacks.py +++ b/modules/processing_callbacks.py @@ -5,14 +5,17 @@ import numpy as np from modules import shared, processing_correction, extra_networks, timer, prompt_parser_diffusers + p = None -debug_callback = shared.log.trace if os.environ.get('SD_CALLBACK_DEBUG', None) is not None else lambda *args, **kwargs: None +debug = os.environ.get('SD_CALLBACK_DEBUG', None) is not None +debug_callback = shared.log.trace if debug else lambda *args, **kwargs: None def set_callbacks_p(processing): global p # pylint: disable=global-statement p = processing + def prompt_callback(step, kwargs): if prompt_parser_diffusers.embedder is None or 'prompt_embeds' not in kwargs: return kwargs @@ -27,6 +30,7 @@ def prompt_callback(step, kwargs): debug_callback(f"Callback: {e}") return kwargs + def diffusers_callback_legacy(step: int, timestep: int, latents: typing.Union[torch.FloatTensor, np.ndarray]): if p is None: return @@ -50,7 +54,8 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {} if p is None: return kwargs latents = kwargs.get('latents', None) - debug_callback(f'Callback: step={step} timestep={timestep} latents={latents.shape if latents is not None else None} kwargs={list(kwargs)}') + if debug: + debug_callback(f'Callback: step={step} timestep={timestep} latents={latents.shape if latents is not None else None} kwargs={list(kwargs)}') order = getattr(pipe.scheduler, "order", 1) if hasattr(pipe, 'scheduler') else 1 shared.state.sampling_step = step // order if shared.state.interrupted or shared.state.skipped: @@ -61,13 +66,13 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {} if shared.state.interrupted or shared.state.skipped: raise AssertionError('Interrupted...') time.sleep(0.1) - if hasattr(p, "stepwise_lora"): - extra_networks.activate(p, p.extra_network_data, step=step) + if hasattr(p, "stepwise_lora") and 
shared.native: + extra_networks.activate(p, step=step) if latents is None: return kwargs elif shared.opts.nan_skip: assert not torch.isnan(latents[..., 0, 0]).all(), f'NaN detected at step {step}: Skipping...' - if len(getattr(p, 'ip_adapter_names', [])) > 0: + if len(getattr(p, 'ip_adapter_names', [])) > 0 and p.ip_adapter_names[0] != 'None': ip_adapter_scales = list(p.ip_adapter_scales) ip_adapter_starts = list(p.ip_adapter_starts) ip_adapter_ends = list(p.ip_adapter_ends) @@ -78,7 +83,7 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {} debug_callback(f"Callback: IP Adapter scales={ip_adapter_scales}") pipe.set_ip_adapter_scale(ip_adapter_scales) if step != getattr(pipe, 'num_timesteps', 0): - kwargs = processing_correction.correction_callback(p, timestep, kwargs) + kwargs = processing_correction.correction_callback(p, timestep, kwargs, initial=step == 0) kwargs = prompt_callback(step, kwargs) # monkey patch for diffusers callback issues if step == int(getattr(pipe, 'num_timesteps', 100) * p.cfg_end) and 'prompt_embeds' in kwargs and 'negative_prompt_embeds' in kwargs: if "PAG" in shared.sd_model.__class__.__name__: @@ -105,7 +110,5 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {} if shared.cmd_opts.profile and shared.profiler is not None: shared.profiler.step() t1 = time.time() - if 'callback' not in timer.process.records: - timer.process.records['callback'] = 0 - timer.process.records['callback'] += t1 - t0 + timer.process.add('callback', t1 - t0) return kwargs diff --git a/modules/processing_class.py b/modules/processing_class.py index 79f51576f..97e5f7b6f 100644 --- a/modules/processing_class.py +++ b/modules/processing_class.py @@ -31,8 +31,8 @@ def __init__(self, n_iter: int = 1, steps: int = 50, clip_skip: int = 1, - width: int = 512, - height: int = 512, + width: int = 1024, + height: int = 1024, # samplers sampler_index: int = None, # pylint: disable=unused-argument # used only to set sampler_name sampler_name: str = None, @@ -139,6 +139,7 @@ def __init__(self, self.negative_pooleds = [] self.disable_extra_networks = False self.iteration = 0 + self.network_data = {} # initializers self.prompt = prompt @@ -169,7 +170,7 @@ def __init__(self, self.image_cfg_scale = image_cfg_scale self.scale_by = scale_by self.mask = mask - self.image_mask = mask # TODO duplciate mask params + self.image_mask = mask # TODO processing: remove duplicate mask params self.latent_mask = latent_mask self.mask_blur = mask_blur self.inpainting_fill = inpainting_fill @@ -338,6 +339,9 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs def close(self): self.sampler = None # pylint: disable=attribute-defined-outside-init + def __str__(self): + return f'{self.__class__.__name__}: {self.__dict__}' + class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def __init__(self, **kwargs): @@ -347,8 +351,8 @@ def __init__(self, **kwargs): def init(self, all_prompts=None, all_seeds=None, all_subseeds=None): if shared.native: shared.sd_model = sd_models.set_diffuser_pipe(self.sd_model, sd_models.DiffusersTaskType.TEXT_2_IMAGE) - self.width = self.width or 512 - self.height = self.height or 512 + self.width = self.width or 1024 + self.height = self.height or 1024 if all_prompts is not None: self.all_prompts = all_prompts if all_seeds is not None: diff --git a/modules/processing_correction.py b/modules/processing_correction.py index e715d8c49..050fae889 100644 --- a/modules/processing_correction.py +++ 
b/modules/processing_correction.py @@ -7,9 +7,11 @@ import torch from modules import shared, sd_vae_taesd, devices + debug_enabled = os.environ.get('SD_HDR_DEBUG', None) is not None debug = shared.log.trace if debug_enabled else lambda *args, **kwargs: None debug('Trace: HDR') +skip_correction = False def sharpen_tensor(tensor, ratio=0): @@ -116,8 +118,15 @@ def correction(p, timestep, latent): return latent -def correction_callback(p, timestep, kwargs): - if not any([p.hdr_clamp, p.hdr_mode, p.hdr_maximize, p.hdr_sharpen, p.hdr_color, p.hdr_brightness, p.hdr_tint_ratio]): +def correction_callback(p, timestep, kwargs, initial: bool = False): + global skip_correction # pylint: disable=global-statement + if initial: + if not any([p.hdr_clamp, p.hdr_mode, p.hdr_maximize, p.hdr_sharpen, p.hdr_color, p.hdr_brightness, p.hdr_tint_ratio]): + skip_correction = True + return kwargs + else: + skip_correction = False + elif skip_correction: return kwargs latents = kwargs["latents"] if debug_enabled: diff --git a/modules/processing_diffusers.py b/modules/processing_diffusers.py index 2164134b1..33d875e95 100644 --- a/modules/processing_diffusers.py +++ b/modules/processing_diffusers.py @@ -4,10 +4,12 @@ import numpy as np import torch import torchvision.transforms.functional as TF -from modules import shared, devices, processing, sd_models, errors, sd_hijack_hypertile, processing_vae, sd_models_compile, hidiffusion, timer, modelstats +from PIL import Image +from modules import shared, devices, processing, sd_models, errors, sd_hijack_hypertile, processing_vae, sd_models_compile, hidiffusion, timer, modelstats, extra_networks from modules.processing_helpers import resize_hires, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, save_intermediate, update_sampler, is_txt2img, is_refiner_enabled from modules.processing_args import set_pipeline_args from modules.onnx_impl import preprocess_pipeline as preprocess_onnx_pipeline, check_parameters_changed as olive_check_parameters_changed +from modules.lora import networks debug = shared.log.trace if os.environ.get('SD_DIFFUSERS_DEBUG', None) is not None else lambda *args, **kwargs: None @@ -75,7 +77,6 @@ def process_base(p: processing.StableDiffusionProcessing): clip_skip=p.clip_skip, desc='Base', ) - timer.process.record('args') shared.state.sampling_steps = base_args.get('prior_num_inference_steps', None) or p.steps or base_args.get('num_inference_steps', None) if shared.opts.scheduler_eta is not None and shared.opts.scheduler_eta > 0 and shared.opts.scheduler_eta < 1: p.extra_generation_params["Sampler Eta"] = shared.opts.scheduler_eta @@ -83,11 +84,13 @@ def process_base(p: processing.StableDiffusionProcessing): try: t0 = time.time() sd_models_compile.check_deepcache(enable=True) + shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) sd_models.move_model(shared.sd_model, devices.device) if hasattr(shared.sd_model, 'unet'): sd_models.move_model(shared.sd_model.unet, devices.device) if hasattr(shared.sd_model, 'transformer'): sd_models.move_model(shared.sd_model.transformer, devices.device) + extra_networks.activate(p, exclude=['text_encoder', 'text_encoder_2']) hidiffusion.apply(p, shared.sd_model_type) # if 'image' in base_args: # base_args['image'] = set_latents(p) @@ -199,11 +202,6 @@ def process_hires(p: processing.StableDiffusionProcessing, output): if hasattr(shared.sd_model, "vae") and output.images is not None and len(output.images) > 0: output.images = processing_vae.vae_decode(latents=output.images, 
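The `initial` flag added to `correction_callback` means the HDR options are inspected once on the first callback and the result is cached in a module-level flag, so later steps return immediately. A condensed sketch of the pattern (simplified attribute list, placeholder correction, not the full HDR logic):

```python
_skip = False  # module-level cache, analogous to skip_correction above

def correction_callback(p, kwargs, initial=False):
    global _skip
    if initial:  # first step of the run: decide once whether any HDR option is active
        _skip = not any([p.hdr_clamp, p.hdr_maximize, p.hdr_sharpen, p.hdr_color, p.hdr_brightness])
    if _skip:  # later steps exit without touching the latents
        return kwargs
    kwargs['latents'] = kwargs['latents']  # placeholder for the real per-step correction
    return kwargs
```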
model=shared.sd_model, full_quality=p.full_quality, output_type='pil', width=p.hr_upscale_to_x, height=p.hr_upscale_to_y) # controlnet cannnot deal with latent input p.task_args['image'] = output.images # replace so hires uses new output - sd_models.move_model(shared.sd_model, devices.device) - if hasattr(shared.sd_model, 'unet'): - sd_models.move_model(shared.sd_model.unet, devices.device) - if hasattr(shared.sd_model, 'transformer'): - sd_models.move_model(shared.sd_model.transformer, devices.device) update_sampler(p, shared.sd_model, second_pass=True) orig_denoise = p.denoising_strength p.denoising_strength = strength @@ -227,11 +225,20 @@ def process_hires(p: processing.StableDiffusionProcessing, output): shared.state.job = 'HiRes' shared.state.sampling_steps = hires_args.get('prior_num_inference_steps', None) or p.steps or hires_args.get('num_inference_steps', None) try: + shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) + sd_models.move_model(shared.sd_model, devices.device) + if hasattr(shared.sd_model, 'unet'): + sd_models.move_model(shared.sd_model.unet, devices.device) + if hasattr(shared.sd_model, 'transformer'): + sd_models.move_model(shared.sd_model.transformer, devices.device) + if 'base' in p.skip: + extra_networks.activate(p) sd_models_compile.check_deepcache(enable=True) output = shared.sd_model(**hires_args) # pylint: disable=not-callable if isinstance(output, dict): output = SimpleNamespace(**output) - shared.history.add(output.images, info=processing.create_infotext(p), ops=p.ops) + if hasattr(output, 'images'): + shared.history.add(output.images, info=processing.create_infotext(p), ops=p.ops) sd_models_compile.check_deepcache(enable=False) sd_models_compile.openvino_post_compile(op="base") except AssertionError as e: @@ -263,8 +270,7 @@ def process_refine(p: processing.StableDiffusionProcessing, output): if shared.state.interrupted or shared.state.skipped: shared.sd_model = orig_pipeline return output - if shared.opts.diffusers_offload_mode == "balanced": - shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) + shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) if shared.opts.diffusers_move_refiner: sd_models.move_model(shared.sd_refiner, devices.device) if hasattr(shared.sd_refiner, 'unet'): @@ -313,7 +319,8 @@ def process_refine(p: processing.StableDiffusionProcessing, output): output = shared.sd_refiner(**refiner_args) # pylint: disable=not-callable if isinstance(output, dict): output = SimpleNamespace(**output) - shared.history.add(output.images, info=processing.create_infotext(p), ops=p.ops) + if hasattr(output, 'images'): + shared.history.add(output.images, info=processing.create_infotext(p), ops=p.ops) sd_models_compile.openvino_post_compile(op="refiner") except AssertionError as e: shared.log.info(e) @@ -323,13 +330,6 @@ def process_refine(p: processing.StableDiffusionProcessing, output): errors.display(e, 'Processing') modelstats.analyze() - """ # TODO decode using refiner - if not shared.state.interrupted and not shared.state.skipped: - refiner_images = processing_vae.vae_decode(latents=refiner_output.images, model=shared.sd_refiner, full_quality=True, width=max(p.width, p.hr_upscale_to_x), height=max(p.height, p.hr_upscale_to_y)) - for refiner_image in refiner_images: - results.append(refiner_image) - """ - if shared.opts.diffusers_offload_mode == "balanced": shared.sd_refiner = sd_models.apply_balanced_offload(shared.sd_refiner) elif shared.opts.diffusers_move_refiner: @@ -343,35 +343,54 @@ def 
process_refine(p: processing.StableDiffusionProcessing, output): def process_decode(p: processing.StableDiffusionProcessing, output): + shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model, exclude=['vae']) if output is not None: if not hasattr(output, 'images') and hasattr(output, 'frames'): shared.log.debug(f'Generated: frames={len(output.frames[0])}') output.images = output.frames[0] + if output.images is not None and len(output.images) > 0 and isinstance(output.images[0], Image.Image): + return output.images model = shared.sd_model if not is_refiner_enabled(p) else shared.sd_refiner if not hasattr(model, 'vae'): if hasattr(model, 'pipe') and hasattr(model.pipe, 'vae'): model = model.pipe - if hasattr(model, "vae") and output.images is not None and len(output.images) > 0: + if (hasattr(model, "vae") or hasattr(model, "vqgan")) and output.images is not None and len(output.images) > 0: if p.hr_resize_mode > 0 and (p.hr_upscaler != 'None' or p.hr_resize_mode == 5): width = max(getattr(p, 'width', 0), getattr(p, 'hr_upscale_to_x', 0)) height = max(getattr(p, 'height', 0), getattr(p, 'hr_upscale_to_y', 0)) else: width = getattr(p, 'width', 0) height = getattr(p, 'height', 0) - results = processing_vae.vae_decode( - latents = output.images, - model = model, - full_quality = p.full_quality, - width = width, - height = height, - ) + frames = p.task_args.get('num_frames', None) + if isinstance(output.images, list): + results = [] + for i in range(len(output.images)): + result_batch = processing_vae.vae_decode( + latents = output.images[i], + model = model, + full_quality = p.full_quality, + width = width, + height = height, + frames = frames, + ) + for result in list(result_batch): + results.append(result) + else: + results = processing_vae.vae_decode( + latents = output.images, + model = model, + full_quality = p.full_quality, + width = width, + height = height, + frames = frames, + ) elif hasattr(output, 'images'): results = output.images else: - shared.log.warning('Processing returned no results') + shared.log.warning('Processing: no results') results = [] else: - shared.log.warning('Processing returned no results') + shared.log.warning('Processing: no results') results = [] return results @@ -425,9 +444,11 @@ def process_diffusers(p: processing.StableDiffusionProcessing): if p.negative_prompts is None or len(p.negative_prompts) == 0: p.negative_prompts = p.all_negative_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size] - sd_models.move_model(shared.sd_model, devices.device) sd_models_compile.openvino_recompile_model(p, hires=False, refiner=False) # recompile if a parameter changes + if hasattr(p, 'dummy'): + images = [Image.new(mode='RGB', size=(p.width, p.height))] + return images if 'base' not in p.skip: output = process_base(p) else: @@ -450,10 +471,16 @@ def process_diffusers(p: processing.StableDiffusionProcessing): shared.sd_model = orig_pipeline return results - results = process_decode(p, output) + extra_networks.deactivate(p) + timer.process.add('lora', networks.timer.total) + networks.timer.clear(complete=True) + results = process_decode(p, output) timer.process.record('decode') + shared.sd_model = orig_pipeline + # shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) + if p.state == '': global last_p # pylint: disable=global-statement last_p = p diff --git a/modules/processing_helpers.py b/modules/processing_helpers.py index ec7fbf048..304e2c211 100644 --- a/modules/processing_helpers.py +++ b/modules/processing_helpers.py @@ -1,4 +1,5 
@@ import os +import time import math import random import warnings @@ -9,7 +10,7 @@ from PIL import Image from skimage import exposure from blendmodes.blend import blendLayers, BlendType -from modules import shared, devices, images, sd_models, sd_samplers, sd_hijack_hypertile, processing_vae +from modules import shared, devices, images, sd_models, sd_samplers, sd_hijack_hypertile, processing_vae, timer debug = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None @@ -158,7 +159,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see # enables the generation of additional tensors with noise that the sampler will use during its processing. # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to # produce the same images as with two batches [100], [101]. - if p is not None and p.sampler is not None and (len(seeds) > 1 and shared.opts.enable_batch_seeds or eta_noise_seed_delta > 0): + if p is not None and p.sampler is not None and ((len(seeds) > 1 and shared.opts.enable_batch_seeds) or (eta_noise_seed_delta > 0)): sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] else: sampler_noises = None @@ -352,6 +353,7 @@ def diffusers_image_conditioning(_source_image, latent_image, _image_mask=None): def validate_sample(tensor): + t0 = time.time() if not isinstance(tensor, np.ndarray) and not isinstance(tensor, torch.Tensor): return tensor dtype = tensor.dtype @@ -366,17 +368,18 @@ def validate_sample(tensor): sample = 255.0 * np.moveaxis(sample, 0, 2) if not shared.native else 255.0 * sample with warnings.catch_warnings(record=True) as w: cast = sample.astype(np.uint8) - minimum, maximum, mean = np.min(cast), np.max(cast), np.mean(cast) - if len(w) > 0 or minimum == maximum: + if len(w) > 0: nans = np.isnan(sample).sum() cast = np.nan_to_num(sample) cast = cast.astype(np.uint8) vae = shared.sd_model.vae.dtype if hasattr(shared.sd_model, 'vae') else None upcast = getattr(shared.sd_model.vae.config, 'force_upcast', None) if hasattr(shared.sd_model, 'vae') and hasattr(shared.sd_model.vae, 'config') else None - shared.log.error(f'Decode: sample={sample.shape} invalid={nans} mean={mean} dtype={dtype} vae={vae} upcast={upcast} failed to validate') + shared.log.error(f'Decode: sample={sample.shape} invalid={nans} dtype={dtype} vae={vae} upcast={upcast} failed to validate') if upcast is not None and not upcast: setattr(shared.sd_model.vae.config, 'force_upcast', True) # noqa: B010 shared.log.warning('Decode: upcast=True set, retry operation') + t1 = time.time() + timer.process.add('validate', t1 - t0) return cast @@ -411,7 +414,7 @@ def resize_hires(p, latents): # input=latents output=pil if not latent_upscaler if latent_upscaler is not None: return torch.nn.functional.interpolate(latents, size=(p.hr_upscale_to_y // 8, p.hr_upscale_to_x // 8), mode=latent_upscaler["mode"], antialias=latent_upscaler["antialias"]) first_pass_images = processing_vae.vae_decode(latents=latents, model=shared.sd_model, full_quality=p.full_quality, output_type='pil', width=p.width, height=p.height) - if p.hr_upscale_to_x == 0 or p.hr_upscale_to_y == 0 and hasattr(p, 'init_hr'): + if p.hr_upscale_to_x == 0 or (p.hr_upscale_to_y == 0 and hasattr(p, 'init_hr')): shared.log.error('Hires: missing upscaling dimensions') return first_pass_images resized_images = [] @@ -561,7 +564,9 @@ def save_intermediate(p, latents, suffix): def update_sampler(p, sd_model, second_pass=False): 
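`validate_sample` now times the call and triggers NaN handling purely off the cast warning; the core of that NaN-tolerant cast, shown standalone (recent NumPy emits a RuntimeWarning when NaN/inf values are cast to integers):

```python
import warnings
import numpy as np

def to_uint8(sample: np.ndarray) -> np.ndarray:
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        cast = sample.astype(np.uint8)
    if len(w) > 0:  # a warning was raised -> NaN/inf present, sanitize and recast
        cast = np.nan_to_num(sample).astype(np.uint8)
    return cast

print(to_uint8(np.array([0.0, 255.0, np.nan])))  # [  0 255   0]
```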
sampler_selection = p.hr_sampler_name if second_pass else p.sampler_name if hasattr(sd_model, 'scheduler'): - if sampler_selection is None or sampler_selection == 'None': + if sampler_selection == 'None': + return + if sampler_selection is None: sampler = sd_samplers.all_samplers_map.get("UniPC") else: sampler = sd_samplers.all_samplers_map.get(sampler_selection, None) diff --git a/modules/processing_info.py b/modules/processing_info.py index 714ebf35f..e0fca12ae 100644 --- a/modules/processing_info.py +++ b/modules/processing_info.py @@ -140,6 +140,7 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No if sd_hijack is not None and hasattr(sd_hijack.model_hijack, 'embedding_db') and len(sd_hijack.model_hijack.embedding_db.embeddings_used) > 0: # this is for original hijaacked models only, diffusers are handled separately args["Embeddings"] = ', '.join(sd_hijack.model_hijack.embedding_db.embeddings_used) # samplers + if getattr(p, 'sampler_name', None) is not None: args["Sampler eta delta"] = shared.opts.eta_noise_seed_delta if shared.opts.eta_noise_seed_delta != 0 and sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p) else None args["Sampler eta multiplier"] = p.initial_noise_multiplier if getattr(p, 'initial_noise_multiplier', 1.0) != 1.0 else None diff --git a/modules/processing_vae.py b/modules/processing_vae.py index 3c0357c81..95b83daa4 100644 --- a/modules/processing_vae.py +++ b/modules/processing_vae.py @@ -33,6 +33,62 @@ def create_latents(image, p, dtype=None, device=None): return latents +def full_vqgan_decode(latents, model): + t0 = time.time() + if model is None or not hasattr(model, 'vqgan'): + shared.log.error('VQGAN not found in model') + return [] + if debug: + devices.torch_gc(force=True) + shared.mem_mon.reset() + + base_device = None + if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False): + base_device = sd_models.move_base(model, devices.cpu) + + if shared.opts.diffusers_offload_mode == "balanced": + shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) + elif shared.opts.diffusers_offload_mode != "sequential": + sd_models.move_model(model.vqgan, devices.device) + + latents = latents.to(devices.device, dtype=model.vqgan.dtype) + + #normalize latents + scaling_factor = model.vqgan.config.get("scale_factor", None) + if scaling_factor: + latents = latents * scaling_factor + + vae_name = os.path.splitext(os.path.basename(sd_vae.loaded_vae_file))[0] if sd_vae.loaded_vae_file is not None else "default" + vae_stats = f'name="{vae_name}" dtype={model.vqgan.dtype} device={model.vqgan.device}' + latents_stats = f'shape={latents.shape} dtype={latents.dtype} device={latents.device}' + stats = f'vae {vae_stats} latents {latents_stats}' + + log_debug(f'VAE config: {model.vqgan.config}') + try: + decoded = model.vqgan.decode(latents).sample.clamp(0, 1) + except Exception as e: + shared.log.error(f'VAE decode: {stats} {e}') + errors.display(e, 'VAE decode') + decoded = [] + + # delete vae after OpenVINO compile + if 'VAE' in shared.opts.cuda_compile and shared.opts.cuda_compile_backend == "openvino_fx" and shared.compiled_model_state.first_pass_vae: + shared.compiled_model_state.first_pass_vae = False + if not shared.opts.openvino_disable_memory_cleanup and hasattr(shared.sd_model, "vqgan"): + model.vqgan.apply(sd_models.convert_to_faketensors) + devices.torch_gc(force=True) + + if shared.opts.diffusers_offload_mode == "balanced": + shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) + 
elif shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and base_device is not None: + sd_models.move_base(model, base_device) + t1 = time.time() + if debug: + log_debug(f'VAE memory: {shared.mem_mon.read()}') + shared.log.debug(f'VAE decode: {stats} time={round(t1-t0, 3)}') + return decoded + + def full_vae_decode(latents, model): t0 = time.time() if not hasattr(model, 'vae') and hasattr(model, 'pipe'): @@ -48,8 +104,6 @@ def full_vae_decode(latents, model): if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False): base_device = sd_models.move_base(model, devices.cpu) - if shared.opts.diffusers_offload_mode == "balanced": - shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) elif shared.opts.diffusers_offload_mode != "sequential": sd_models.move_model(model.vae, devices.device) @@ -60,10 +114,11 @@ def full_vae_decode(latents, model): else: # manual upcast and we restore it later model.vae.orig_dtype = model.vae.dtype model.vae = model.vae.to(dtype=torch.float32) - latents = latents.to(torch.float32) latents = latents.to(devices.device) if getattr(model.vae, "post_quant_conv", None) is not None: latents = latents.to(next(iter(model.vae.post_quant_conv.parameters())).dtype) + else: + latents = latents.to(model.vae.dtype) # normalize latents latents_mean = model.vae.config.get("latents_mean", None) @@ -86,10 +141,12 @@ def full_vae_decode(latents, model): log_debug(f'VAE config: {model.vae.config}') try: - decoded = model.vae.decode(latents, return_dict=False)[0] + with devices.inference_context(): + decoded = model.vae.decode(latents, return_dict=False)[0] except Exception as e: shared.log.error(f'VAE decode: {stats} {e}') - errors.display(e, 'VAE decode') + if 'out of memory' not in str(e): + errors.display(e, 'VAE decode') decoded = [] if hasattr(model.vae, "orig_dtype"): @@ -103,8 +160,6 @@ def full_vae_decode(latents, model): model.vae.apply(sd_models.convert_to_faketensors) devices.torch_gc(force=True) - if shared.opts.diffusers_offload_mode == "balanced": - shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) elif shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and base_device is not None: sd_models.move_base(model, base_device) t1 = time.time() @@ -147,7 +202,7 @@ def taesd_vae_encode(image): return encoded -def vae_decode(latents, model, output_type='np', full_quality=True, width=None, height=None): +def vae_decode(latents, model, output_type='np', full_quality=True, width=None, height=None, frames=None): t0 = time.time() model = model or shared.sd_model if not hasattr(model, 'vae') and hasattr(model, 'pipe'): @@ -161,11 +216,15 @@ def vae_decode(latents, model, output_type='np', full_quality=True, width=None, return [] if shared.state.interrupted or shared.state.skipped: return [] - if not hasattr(model, 'vae'): + if not hasattr(model, 'vae') and not hasattr(model, 'vqgan'): shared.log.error('VAE not found in model') return [] - if hasattr(model, "_unpack_latents") and hasattr(model, "vae_scale_factor") and width is not None and height is not None: # FLUX + if hasattr(model, '_unpack_latents') and hasattr(model, 'transformer_spatial_patch_size') and frames is not None: # LTX + latent_num_frames = (frames - 1) // model.vae_temporal_compression_ratio + 1 + latents = model._unpack_latents(latents.unsqueeze(0), latent_num_frames, height // 32, width // 32, model.transformer_spatial_patch_size, model.transformer_temporal_patch_size) # pylint: disable=protected-access + 
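The new `full_vqgan_decode()` path mirrors the VAE path for pipelines that expose a `vqgan` instead of a `vae`; stripped of offloading, logging and OpenVINO cleanup, the decode itself reduces to roughly the following sketch (attribute names follow the code above, not a public diffusers API):

```python
import torch

def vqgan_to_pil(pipe, latents: torch.Tensor):
    scale = pipe.vqgan.config.get('scale_factor', None)
    if scale:
        latents = latents * scale                            # normalize latents
    decoded = pipe.vqgan.decode(latents).sample.clamp(0, 1)  # BCHW tensor in [0, 1]
    arr = decoded.permute(0, 2, 3, 1).cpu().float().numpy()  # BHWC float array
    return pipe.numpy_to_pil(arr)                            # list of PIL images
```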
latents = model._denormalize_latents(latents, model.vae.latents_mean, model.vae.latents_std, model.vae.config.scaling_factor) # pylint: disable=protected-access + if hasattr(model, '_unpack_latents') and hasattr(model, "vae_scale_factor") and width is not None and height is not None: # FLUX latents = model._unpack_latents(latents, height, width, model.vae_scale_factor) # pylint: disable=protected-access if len(latents.shape) == 3: # lost a batch dim in hires latents = latents.unsqueeze(0) @@ -176,12 +235,20 @@ def vae_decode(latents, model, output_type='np', full_quality=True, width=None, decoded = latents.float().cpu().numpy() elif full_quality and hasattr(model, "vae"): decoded = full_vae_decode(latents=latents, model=model) + elif hasattr(model, "vqgan"): + decoded = full_vqgan_decode(latents=latents, model=model) else: decoded = taesd_vae_decode(latents=latents) if torch.is_tensor(decoded): - if hasattr(model, 'image_processor'): + if hasattr(model, 'video_processor'): + imgs = model.video_processor.postprocess_video(decoded, output_type='pil') + elif hasattr(model, 'image_processor'): imgs = model.image_processor.postprocess(decoded, output_type=output_type) + elif hasattr(model, "vqgan"): + imgs = decoded.permute(0, 2, 3, 1).cpu().float().numpy() + if output_type == "pil": + imgs = model.numpy_to_pil(imgs) else: import diffusers model.image_processor = diffusers.image_processor.VaeImageProcessor() diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 3a1288097..1eb7a77ee 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -147,7 +147,7 @@ def __default__(self, data, children, meta): def get_schedule(prompt): try: tree = schedule_parser.parse(prompt) - except lark.exceptions.LarkError: + except Exception: return [[steps, prompt]] return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)] diff --git a/modules/prompt_parser_diffusers.py b/modules/prompt_parser_diffusers.py index 234272907..4e31c747a 100644 --- a/modules/prompt_parser_diffusers.py +++ b/modules/prompt_parser_diffusers.py @@ -16,6 +16,7 @@ token_dict = None # used by helper get_tokens token_type = None # used by helper get_tokens cache = OrderedDict() +last_attention = None embedder = None @@ -38,8 +39,6 @@ def prepare_model(pipe = None): pipe = pipe.pipe if not hasattr(pipe, "text_encoder"): return None - if shared.opts.diffusers_offload_mode == "balanced": - pipe = sd_models.apply_balanced_offload(pipe) elif hasattr(pipe, "maybe_free_model_hooks"): pipe.maybe_free_model_hooks() devices.torch_gc() @@ -52,7 +51,7 @@ def __init__(self, prompts, negative_prompts, steps, clip_skip, p): self.prompts = prompts self.negative_prompts = negative_prompts self.batchsize = len(self.prompts) - self.attention = None + self.attention = last_attention self.allsame = self.compare_prompts() # collapses batched prompts to single prompt if possible self.steps = steps self.clip_skip = clip_skip @@ -64,6 +63,8 @@ def __init__(self, prompts, negative_prompts, steps, clip_skip, p): self.positive_schedule = None self.negative_schedule = None self.scheduled_prompt = False + if hasattr(p, 'dummy'): + return earlyout = self.checkcache(p) if earlyout: return @@ -93,7 +94,7 @@ def flatten(xss): return [x for xs in xss for x in xs] # unpack EN data in case of TE LoRA - en_data = p.extra_network_data + en_data = p.network_data en_data = [idx.items for item in en_data.values() for idx in item] effective_batch = 1 if self.allsame else self.batchsize key = str([self.prompts, self.negative_prompts, 
effective_batch, self.clip_skip, self.steps, en_data]) @@ -113,6 +114,7 @@ def flatten(xss): debug(f"Prompt cache: add={key}") while len(cache) > int(shared.opts.sd_textencoder_cache_size): cache.popitem(last=False) + return True if item: self.__dict__.update(cache[key]) cache.move_to_end(key) @@ -161,8 +163,10 @@ def extend_embeds(self, batchidx, idx): # Extends scheduled prompt via index self.negative_pooleds[batchidx].append(self.negative_pooleds[batchidx][idx]) def encode(self, pipe, positive_prompt, negative_prompt, batchidx): + global last_attention # pylint: disable=global-statement self.attention = shared.opts.prompt_attention - if self.attention == "xhinker" or 'Flux' in pipe.__class__.__name__: + last_attention = self.attention + if self.attention == "xhinker": prompt_embed, positive_pooled, negative_embed, negative_pooled = get_xhinker_text_embeddings(pipe, positive_prompt, negative_prompt, self.clip_skip) else: prompt_embed, positive_pooled, negative_embed, negative_pooled = get_weighted_text_embeddings(pipe, positive_prompt, negative_prompt, self.clip_skip) @@ -178,7 +182,6 @@ def encode(self, pipe, positive_prompt, negative_prompt, batchidx): if debug_enabled: get_tokens(pipe, 'positive', positive_prompt) get_tokens(pipe, 'negative', negative_prompt) - pipe = prepare_model() def __call__(self, key, step=0): batch = getattr(self, key) @@ -194,8 +197,6 @@ def __call__(self, key, step=0): def compel_hijack(self, token_ids: torch.Tensor, attention_mask: typing.Optional[torch.Tensor] = None) -> torch.Tensor: - if not devices.same_device(self.text_encoder.device, devices.device): - sd_models.move_model(self.text_encoder, devices.device) needs_hidden_states = self.returned_embeddings_type != 1 text_encoder_output = self.text_encoder(token_ids, attention_mask, output_hidden_states=needs_hidden_states, return_dict=True) @@ -372,25 +373,31 @@ def prepare_embedding_providers(pipe, clip_skip) -> list[EmbeddingsProvider]: embedding_type = -(clip_skip + 1) else: embedding_type = clip_skip + embedding_args = { + 'truncate': False, + 'returned_embeddings_type': embedding_type, + 'device': device, + 'dtype_for_device_getter': lambda device: devices.dtype, + } if getattr(pipe, "prior_pipe", None) is not None and getattr(pipe.prior_pipe, "tokenizer", None) is not None and getattr(pipe.prior_pipe, "text_encoder", None) is not None: - provider = EmbeddingsProvider(padding_attention_mask_value=0, tokenizer=pipe.prior_pipe.tokenizer, text_encoder=pipe.prior_pipe.text_encoder, truncate=False, returned_embeddings_type=embedding_type, device=device) + provider = EmbeddingsProvider(padding_attention_mask_value=0, tokenizer=pipe.prior_pipe.tokenizer, text_encoder=pipe.prior_pipe.text_encoder, **embedding_args) embeddings_providers.append(provider) - no_mask_provider = EmbeddingsProvider(padding_attention_mask_value=1 if "sote" in pipe.sd_checkpoint_info.name.lower() else 0, tokenizer=pipe.prior_pipe.tokenizer, text_encoder=pipe.prior_pipe.text_encoder, truncate=False, returned_embeddings_type=embedding_type, device=device) + no_mask_provider = EmbeddingsProvider(padding_attention_mask_value=1 if "sote" in pipe.sd_checkpoint_info.name.lower() else 0, tokenizer=pipe.prior_pipe.tokenizer, text_encoder=pipe.prior_pipe.text_encoder, **embedding_args) embeddings_providers.append(no_mask_provider) elif getattr(pipe, "tokenizer", None) is not None and getattr(pipe, "text_encoder", None) is not None: - if not devices.same_device(pipe.text_encoder.device, devices.device): - sd_models.move_model(pipe.text_encoder, 
devices.device) - provider = EmbeddingsProvider(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder, truncate=False, returned_embeddings_type=embedding_type, device=device) + if pipe.text_encoder.__class__.__name__.startswith('CLIP'): + sd_models.move_model(pipe.text_encoder, devices.device, force=True) + provider = EmbeddingsProvider(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder, **embedding_args) embeddings_providers.append(provider) if getattr(pipe, "tokenizer_2", None) is not None and getattr(pipe, "text_encoder_2", None) is not None: - if not devices.same_device(pipe.text_encoder_2.device, devices.device): - sd_models.move_model(pipe.text_encoder_2, devices.device) - provider = EmbeddingsProvider(tokenizer=pipe.tokenizer_2, text_encoder=pipe.text_encoder_2, truncate=False, returned_embeddings_type=embedding_type, device=device) + if pipe.text_encoder_2.__class__.__name__.startswith('CLIP'): + sd_models.move_model(pipe.text_encoder_2, devices.device, force=True) + provider = EmbeddingsProvider(tokenizer=pipe.tokenizer_2, text_encoder=pipe.text_encoder_2, **embedding_args) embeddings_providers.append(provider) if getattr(pipe, "tokenizer_3", None) is not None and getattr(pipe, "text_encoder_3", None) is not None: - if not devices.same_device(pipe.text_encoder_3.device, devices.device): - sd_models.move_model(pipe.text_encoder_3, devices.device) - provider = EmbeddingsProvider(tokenizer=pipe.tokenizer_3, text_encoder=pipe.text_encoder_3, truncate=False, returned_embeddings_type=embedding_type, device=device) + if pipe.text_encoder_3.__class__.__name__.startswith('CLIP'): + sd_models.move_model(pipe.text_encoder_3, devices.device, force=True) + provider = EmbeddingsProvider(tokenizer=pipe.tokenizer_3, text_encoder=pipe.text_encoder_3, **embedding_args) embeddings_providers.append(provider) return embeddings_providers @@ -583,15 +590,15 @@ def get_xhinker_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", cl te1_device, te2_device, te3_device = None, None, None if hasattr(pipe, "text_encoder") and pipe.text_encoder.device != devices.device: te1_device = pipe.text_encoder.device - sd_models.move_model(pipe.text_encoder, devices.device) + sd_models.move_model(pipe.text_encoder, devices.device, force=True) if hasattr(pipe, "text_encoder_2") and pipe.text_encoder_2.device != devices.device: te2_device = pipe.text_encoder_2.device - sd_models.move_model(pipe.text_encoder_2, devices.device) + sd_models.move_model(pipe.text_encoder_2, devices.device, force=True) if hasattr(pipe, "text_encoder_3") and pipe.text_encoder_3.device != devices.device: te3_device = pipe.text_encoder_3.device - sd_models.move_model(pipe.text_encoder_3, devices.device) + sd_models.move_model(pipe.text_encoder_3, devices.device, force=True) - if is_sd3: + if 'StableDiffusion3' in pipe.__class__.__name__: prompt_embed, negative_embed, positive_pooled, negative_pooled = get_weighted_text_embeddings_sd3(pipe=pipe, prompt=prompt, neg_prompt=neg_prompt, use_t5_encoder=bool(pipe.text_encoder_3)) elif 'Flux' in pipe.__class__.__name__: prompt_embed, positive_pooled = get_weighted_text_embeddings_flux1(pipe=pipe, prompt=prompt, prompt2=prompt_2, device=devices.device) @@ -601,10 +608,10 @@ def get_xhinker_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", cl prompt_embed, negative_embed = get_weighted_text_embeddings_sd15(pipe=pipe, prompt=prompt, neg_prompt=neg_prompt, clip_skip=clip_skip) if te1_device is not None: - sd_models.move_model(pipe.text_encoder, te1_device) + 
sd_models.move_model(pipe.text_encoder, te1_device, force=True) if te2_device is not None: - sd_models.move_model(pipe.text_encoder_2, te1_device) + sd_models.move_model(pipe.text_encoder_2, te1_device, force=True) if te3_device is not None: - sd_models.move_model(pipe.text_encoder_3, te1_device) + sd_models.move_model(pipe.text_encoder_3, te1_device, force=True) return prompt_embed, positive_pooled, negative_embed, negative_pooled diff --git a/modules/pulid/eva_clip/hf_model.py b/modules/pulid/eva_clip/hf_model.py index c4b9fd85b..d148bbff2 100644 --- a/modules/pulid/eva_clip/hf_model.py +++ b/modules/pulid/eva_clip/hf_model.py @@ -31,7 +31,6 @@ class PretrainedConfig: def _camel2snake(s): return re.sub(r'(?DDIM +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. 
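+ # after this shift, sqrt(alpha_bar) is exactly zero at the final timestep (zero terminal SNR); the rescale below keeps sqrt(alpha_bar) at t=0 at its original value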
+ alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + +class BDIA_DDIMScheduler(SchedulerMixin, ConfigMixin): + """ + `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). 
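+ gamma (`float`, defaults to 1.0): + Weight of the BDIA correction term used in `step`; with `gamma=0` the update reduces to a standard DDIM step.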
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, #was True + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", #leading + rescale_betas_zero_snr: bool = False, + gamma: float = 1.0, + + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas #may have to add something for last step + + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + self.next_sample = [] + self.BDIA = False + + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. 
Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." 
+ ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + debug: bool = False, + ) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. + + Args: + model_output (torch.Tensor): Direct output from learned diffusion model + timestep (int): Current discrete timestep in the diffusion chain + sample (torch.Tensor): Current instance of sample created by diffusion process + eta (float): Weight of noise for added noise in diffusion step + use_clipped_model_output (bool): Whether to use clipped model output + generator (torch.Generator, optional): Random number generator + variance_noise (torch.Tensor, optional): Pre-generated noise for variance + return_dict (bool): Whether to return as DDIMSchedulerOutput or tuple + debug (bool): Whether to print debug information + """ + if self.num_inference_steps is None: + raise ValueError("Number of inference steps is 'None', run 'set_timesteps' first") + + # Calculate timesteps + step_size = self.config.num_train_timesteps // self.num_inference_steps + prev_timestep = timestep - step_size + next_timestep = timestep + step_size + + if debug: + print("\n=== Timestep Information ===") + print(f"Current timestep: {timestep}") + print(f"Previous timestep: {prev_timestep}") + print(f"Next timestep: {next_timestep}") + print(f"Step size: {step_size}") + + # Pre-compute alpha and variance values + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** 0.5 + + # Compute required values + alpha_i = alpha_prod_t ** 0.5 + alpha_i_minus_1 = alpha_prod_t_prev ** 0.5 + sigma_i = (1 - alpha_prod_t) ** 0.5 + sigma_i_minus_1 = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 + + if debug: + print("\n=== Alpha Values ===") + print(f"alpha_i: {alpha_i}") + print(f"alpha_i_minus_1: {alpha_i_minus_1}") + print(f"sigma_i: {sigma_i}") + print(f"sigma_i_minus_1: {sigma_i_minus_1}") + + # Predict original sample based on prediction type + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - sigma_i * model_output) / alpha_i + pred_epsilon = model_output + if debug: + print("\nPrediction type: epsilon") + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_i * pred_original_sample) / sigma_i + if debug: + print("\nPrediction type: sample") + elif self.config.prediction_type == "v_prediction": + pred_original_sample = alpha_i * sample - sigma_i * model_output + pred_epsilon = alpha_i * model_output + sigma_i * sample + if debug: + print("\nPrediction type: v_prediction") + else: + raise ValueError( + f"prediction_type {self.config.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`" + ) + + # Apply thresholding or clipping if configured + if self.config.thresholding: + if debug: + print("\nApplying thresholding") + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + if debug: + print("\nApplying clipping") + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 
Recompute pred_epsilon if using clipped model output + if use_clipped_model_output: + if debug: + print("\nUsing clipped model output") + pred_epsilon = (sample - alpha_i * pred_original_sample) / sigma_i + + # Compute DDIM step + ddim_step = alpha_i_minus_1 * pred_original_sample + sigma_i_minus_1 * pred_epsilon + + # Handle initial DDIM step or BDIA steps + if len(self.next_sample) == 0: + if debug: + print("\nFirst iteration (DDIM)") + self.update_next_sample_BDIA(sample) + self.update_next_sample_BDIA(ddim_step) + else: + if debug: + print("\nBDIA step") + # BDIA implementation + alpha_prod_t_next = self.alphas_cumprod[next_timestep] + alpha_i_plus_1 = alpha_prod_t_next ** 0.5 + sigma_i_plus_1 = (1 - alpha_prod_t_next) ** 0.5 + + if debug: + print(f"alpha_i_plus_1: {alpha_i_plus_1}") + print(f"sigma_i_plus_1: {sigma_i_plus_1}") + + a = alpha_i_plus_1 * pred_original_sample + sigma_i_plus_1 * pred_epsilon + bdia_step = ( + self.config.gamma * self.next_sample[-2] + + ddim_step - + (self.config.gamma * a) + ) + self.update_next_sample_BDIA(bdia_step) + + prev_sample = self.next_sample[-1] + + # Apply variance noise if eta > 0 + if eta > 0: + if debug: + print(f"\nApplying variance noise with eta: {eta}") + + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Use either `generator` or `variance_noise`." + ) + + if variance_noise is None: + variance_noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype + ) + prev_sample = prev_sample + std_dev_t * variance_noise + + if not return_dict: + return (prev_sample,) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + 
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def update_next_sample_BDIA(self, new_value): + self.next_sample.append(new_value.clone()) + + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/modules/schedulers/scheduler_dpm_flowmatch.py b/modules/schedulers/scheduler_dpm_flowmatch.py index 1afe54498..69452aca9 100644 --- a/modules/schedulers/scheduler_dpm_flowmatch.py +++ b/modules/schedulers/scheduler_dpm_flowmatch.py @@ -22,7 +22,8 @@ def __init__(self, x, t0, t1, seed=None, **kwargs): t0, t1, self.sign = self.sort(t0, t1) w0 = kwargs.get("w0", torch.zeros_like(x)) if seed is None: - seed = torch.randint(0, 2**63 - 1, []).item() + seed = [torch.randint(0, 2**63 - 1, []).item()] + seed = [s.initial_seed() if isinstance(s, torch.Generator) else s for s in seed] self.batched = True try: assert len(seed) == x.shape[0] diff --git a/modules/schedulers/scheduler_tcd.py b/modules/schedulers/scheduler_tcd.py index bff12ab5a..9b2d4d35a 100644 --- a/modules/schedulers/scheduler_tcd.py +++ b/modules/schedulers/scheduler_tcd.py @@ -447,7 +447,6 @@ def set_timesteps( init_timestep = min(int(self.num_inference_steps * strength), self.num_inference_steps) t_start = max(self.num_inference_steps - init_timestep, 0) timesteps = timesteps[t_start * self.order :] - # TODO: also reset self.num_inference_steps? else: # 2.2 Create the "standard" TCD inference timestep schedule. if num_inference_steps > self.config.num_train_timesteps: diff --git a/modules/schedulers/scheduler_ufogen.py b/modules/schedulers/scheduler_ufogen.py new file mode 100644 index 000000000..f4d8aee97 --- /dev/null +++ b/modules/schedulers/scheduler_ufogen.py @@ -0,0 +1,523 @@ +# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim and https://github.com/xuyanwu/SIDDMs-UFOGen + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->UFOGen +class UFOGenSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. 
+ + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. 
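+ # multiplying by alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) restores the first entry to its original value while the last entry stays at zero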
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class UFOGenScheduler(SchedulerMixin, ConfigMixin): + """ + `UFOGenScheduler` implements multistep and onestep sampling for a UFOGen model, introduced in + [UFOGen: You Forward Once Large Scale Text-to-Image Generation via Diffusion GANs](https://arxiv.org/abs/2311.09257) + by Yanwu Xu, Yang Zhao, Zhisheng Xiao, and Tingbo Hou. UFOGen is a variant of the denoising diffusion GAN (DDGAN) + model designed for one-step sampling. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable + Diffusion. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. 
Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + denoising_step_size (`int`, defaults to 250): + The denoising step size parameter from the UFOGen paper. The number of steps used for training is roughly + `math.ceil(num_train_timesteps / denoising_step_size)`. + """ + + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, + denoising_step_size: int = 250, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + elif beta_schedule == "sigmoid": + # GeoDiff sigmoid schedule + betas = torch.linspace(-6, 6, num_train_timesteps) + self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.custom_timesteps = False + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + return sample + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. 
If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed, + `num_inference_steps` must be `None`. + + """ + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") + + if timesteps is not None: + for i in range(1, len(timesteps)): + if timesteps[i] >= timesteps[i - 1]: + raise ValueError("`custom_timesteps` must be in descending order.") + + if timesteps[0] >= self.config.num_train_timesteps: + raise ValueError( + f"`timesteps` must start before `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps}." + ) + + timesteps = np.array(timesteps, dtype=np.int64) + self.custom_timesteps = True + else: + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + self.custom_timesteps = False + + if num_inference_steps == 1: + # Set the timestep schedule to num_train_timesteps - 1 rather than 0 + # (that is, the one-step timestep schedule is always trailing rather than leading or linspace) + timesteps = np.array([self.config.num_train_timesteps - 1], dtype=np.int64) + else: + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. 
We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[UFOGenSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ufogen.UFOGenSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_ddpm.UFOGenSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ufogen.UFOGenSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + # 0. Resolve timesteps + t = timestep + prev_t = self.previous_timestep(t) + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + # beta_prod_t_prev = 1 - alpha_prod_t_prev + # current_alpha_t = alpha_prod_t / alpha_prod_t_prev + # current_beta_t = 1 - current_alpha_t + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for UFOGenScheduler." + ) + + # 3. 
Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 4. Single-step or multi-step sampling + # Noise is not used on the final timestep of the timestep schedule. + # This also means that noise is not used for one-step sampling. + if t != self.timesteps[-1]: + device = model_output.device + noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=model_output.dtype) + sqrt_alpha_prod_t_prev = alpha_prod_t_prev**0.5 + sqrt_one_minus_alpha_prod_t_prev = (1 - alpha_prod_t_prev) ** 0.5 + pred_prev_sample = sqrt_alpha_prod_t_prev * pred_original_sample + sqrt_one_minus_alpha_prod_t_prev * noise + else: + # Simply return the pred_original_sample. If `prediction_type == "sample"`, this is equivalent to returning + # the output of the GAN generator U-Net on the initial noisy latents x_T ~ N(0, I). + pred_prev_sample = pred_original_sample + + if not return_dict: + return (pred_prev_sample,) + + return UFOGenSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity( + self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep + def previous_timestep(self, timestep): + if 
self.custom_timesteps: + index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] + if index == self.timesteps.shape[0] - 1: + prev_t = torch.tensor(-1) + else: + prev_t = self.timesteps[index + 1] + else: + num_inference_steps = ( + self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps + ) + prev_t = timestep - self.config.num_train_timesteps // num_inference_steps + + return prev_t \ No newline at end of file diff --git a/modules/schedulers/scheduler_vdm.py b/modules/schedulers/scheduler_vdm.py index 543b29ff3..492c30a0c 100644 --- a/modules/schedulers/scheduler_vdm.py +++ b/modules/schedulers/scheduler_vdm.py @@ -147,7 +147,7 @@ def __init__( self.timesteps = torch.from_numpy(self.get_timesteps(len(self))) if num_train_timesteps: alphas_cumprod = self.alphas_cumprod(torch.flip(self.timesteps, dims=(0,))) - alphas = alphas_cumprod[1:] / alphas_cumprod[:-1] # TODO: Might not be exact + alphas = alphas_cumprod[1:] / alphas_cumprod[:-1] self.alphas = torch.cat([alphas_cumprod[:1], alphas]) self.betas = 1 - self.alphas diff --git a/modules/scripts.py b/modules/scripts.py index cf2cf25b9..9c410a3b6 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -288,6 +288,7 @@ def register_scripts_from_module(module, scriptfile): current_basedir = paths.script_path t.record(os.path.basename(scriptfile.basedir) if scriptfile.basedir != paths.script_path else scriptfile.filename) sys.path = syspath + global scripts_txt2img, scripts_img2img, scripts_control, scripts_postproc # pylint: disable=global-statement scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() @@ -329,6 +330,7 @@ def __init__(self): self.scripts = [] self.selectable_scripts = [] self.alwayson_scripts = [] + self.auto_processing_scripts = [] self.titles = [] self.infotext_fields = [] self.paste_field_names = [] @@ -336,6 +338,31 @@ def __init__(self): self.is_img2img = False self.inputs = [None] + def add_script(self, script_class, path, is_img2img, is_control): + try: + script = script_class() + script.filename = path + script.is_txt2img = not is_img2img + script.is_img2img = is_img2img + if is_control: # this is messy but show is a legacy function that is not aware of control tab + v1 = script.show(script.is_txt2img) + v2 = script.show(script.is_img2img) + if v1 == AlwaysVisible or v2 == AlwaysVisible: + visibility = AlwaysVisible + else: + visibility = v1 or v2 + else: + visibility = script.show(script.is_img2img) + if visibility == AlwaysVisible: + self.scripts.append(script) + self.alwayson_scripts.append(script) + script.alwayson = True + elif visibility: + self.scripts.append(script) + self.selectable_scripts.append(script) + except Exception as e: + errors.log.error(f'Script initialize: {path} {e}') + def initialize_scripts(self, is_img2img=False, is_control=False): from modules import scripts_auto_postprocessing @@ -350,34 +377,14 @@ def initialize_scripts(self, is_img2img=False, is_control=False): self.scripts.clear() self.alwayson_scripts.clear() self.selectable_scripts.clear() - auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data() + self.auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data() - all_scripts = auto_processing_scripts + scripts_data - sorted_scripts = sorted(all_scripts, key=lambda x: x.script_class().title().lower()) + sorted_scripts = sorted(scripts_data, key=lambda x: x.script_class().title().lower()) for script_class, path, _basedir, _script_module in sorted_scripts: 
- try: - script = script_class() - script.filename = path - script.is_txt2img = not is_img2img - script.is_img2img = is_img2img - if is_control: # this is messy but show is a legacy function that is not aware of control tab - v1 = script.show(script.is_txt2img) - v2 = script.show(script.is_img2img) - if v1 == AlwaysVisible or v2 == AlwaysVisible: - visibility = AlwaysVisible - else: - visibility = v1 or v2 - else: - visibility = script.show(script.is_img2img) - if visibility == AlwaysVisible: - self.scripts.append(script) - self.alwayson_scripts.append(script) - script.alwayson = True - elif visibility: - self.scripts.append(script) - self.selectable_scripts.append(script) - except Exception as e: - errors.log.error(f'Script initialize: {path} {e}') + self.add_script(script_class, path, is_img2img, is_control) + sorted_scripts = sorted(self.auto_processing_scripts, key=lambda x: x.script_class().title().lower()) + for script_class, path, _basedir, _script_module in sorted_scripts: + self.add_script(script_class, path, is_img2img, is_control) def prepare_ui(self): self.inputs = [None] diff --git a/modules/sd_checkpoint.py b/modules/sd_checkpoint.py index afc5842e4..6ab396329 100644 --- a/modules/sd_checkpoint.py +++ b/modules/sd_checkpoint.py @@ -123,13 +123,17 @@ def list_models(): checkpoint_aliases.clear() ext_filter = [".safetensors"] if shared.opts.sd_disable_ckpt or shared.native else [".ckpt", ".safetensors"] model_list = list(modelloader.load_models(model_path=model_path, model_url=None, command_path=shared.opts.ckpt_dir, ext_filter=ext_filter, download_name=None, ext_blacklist=[".vae.ckpt", ".vae.safetensors"])) + safetensors_list = [] for filename in sorted(model_list, key=str.lower): checkpoint_info = CheckpointInfo(filename) + safetensors_list.append(checkpoint_info) if checkpoint_info.name is not None: checkpoint_info.register() + diffusers_list = [] if shared.native: for repo in modelloader.load_diffusers_models(clear=True): checkpoint_info = CheckpointInfo(repo['name'], sha=repo['hash']) + diffusers_list.append(checkpoint_info) if checkpoint_info.name is not None: checkpoint_info.register() if shared.cmd_opts.ckpt is not None: @@ -143,7 +147,7 @@ def list_models(): shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title elif shared.cmd_opts.ckpt != shared.default_sd_model_file and shared.cmd_opts.ckpt is not None: shared.log.warning(f'Load model: path="{shared.cmd_opts.ckpt}" not found') - shared.log.info(f'Available Models: path="{shared.opts.ckpt_dir}" items={len(checkpoints_list)} time={time.time()-t0:.2f}') + shared.log.info(f'Available Models: items={len(checkpoints_list)} safetensors="{shared.opts.ckpt_dir}":{len(safetensors_list)} diffusers="{shared.opts.diffusers_dir}":{len(diffusers_list)} time={time.time()-t0:.2f}') checkpoints_list = dict(sorted(checkpoints_list.items(), key=lambda cp: cp[1].filename)) def update_model_hashes(): @@ -168,7 +172,10 @@ def update_model_hashes(): def get_closet_checkpoint_match(s: str): if s.startswith('https://huggingface.co/'): - s = s.replace('https://huggingface.co/', '') + model_name = s.replace('https://huggingface.co/', '') + checkpoint_info = CheckpointInfo(model_name) # create a virtual model info + checkpoint_info.type = 'huggingface' + return checkpoint_info if s.startswith('huggingface/'): model_name = s.replace('huggingface/', '') checkpoint_info = CheckpointInfo(model_name) # create a virtual model info @@ -185,6 +192,11 @@ def get_closet_checkpoint_match(s: str): if found and len(found) == 1: return found[0] + # 
absolute path + if s.endswith('.safetensors') and os.path.isfile(s): + checkpoint_info = CheckpointInfo(s) + return checkpoint_info + # reference search """ found = sorted([info for info in shared.reference_models.values() if os.path.basename(info['path']).lower().startswith(s.lower())], key=lambda x: len(x['path'])) @@ -198,8 +210,9 @@ def get_closet_checkpoint_match(s: str): if shared.opts.sd_checkpoint_autodownload and s.count('/') == 1: modelloader.hf_login() found = modelloader.find_diffuser(s, full=True) + found = [f for f in found if f == s] shared.log.info(f'HF search: model="{s}" results={found}') - if found is not None and len(found) == 1 and found[0] == s: + if found is not None and len(found) == 1: checkpoint_info = CheckpointInfo(s) checkpoint_info.type = 'huggingface' return checkpoint_info @@ -240,6 +253,8 @@ def select_checkpoint(op='model'): model_checkpoint = shared.opts.data.get('sd_model_refiner', None) else: model_checkpoint = shared.opts.sd_model_checkpoint + if len(model_checkpoint) < 3: + return None if model_checkpoint is None or model_checkpoint == 'None': return None checkpoint_info = get_closet_checkpoint_match(model_checkpoint) @@ -266,6 +281,12 @@ def select_checkpoint(op='model'): return checkpoint_info +def init_metadata(): + global sd_metadata # pylint: disable=global-statement + if sd_metadata is None: + sd_metadata = shared.readfile(sd_metadata_file, lock=True) if os.path.isfile(sd_metadata_file) else {} + + def read_metadata_from_safetensors(filename): global sd_metadata # pylint: disable=global-statement if sd_metadata is None: diff --git a/modules/sd_detect.py b/modules/sd_detect.py index 062bb32e1..fe3d325c9 100644 --- a/modules/sd_detect.py +++ b/modules/sd_detect.py @@ -71,6 +71,8 @@ def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False): guess = 'Stable Cascade' if 'pixart-sigma' in f.lower(): guess = 'PixArt-Sigma' + if 'sana' in f.lower(): + guess = 'Sana' if 'lumina-next' in f.lower(): guess = 'Lumina-Next' if 'kolors' in f.lower(): @@ -94,6 +96,17 @@ def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False): guess = 'FLUX' if size > 11000 and size < 16000: warn(f'Model detected as FLUX UNET model, but attempting to load a base model: {op}={f} size={size} MB') + # guess for diffusers + index = os.path.join(f, 'model_index.json') + if os.path.exists(index) and os.path.isfile(index): + index = shared.readfile(index, silent=True) + cls = index.get('_class_name', None) + if cls is not None: + pipeline = getattr(diffusers, cls) + if 'Flux' in pipeline.__name__: + guess = 'FLUX' + if 'StableDiffusion3' in pipeline.__name__: + guess = 'Stable Diffusion 3' # switch for specific variant if guess == 'Stable Diffusion' and 'inpaint' in f.lower(): guess = 'Stable Diffusion Inpaint' @@ -105,7 +118,7 @@ def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False): guess = 'Stable Diffusion XL Instruct' # get actual pipeline pipeline = shared_items.get_pipelines().get(guess, None) if pipeline is None else pipeline - if not quiet: + if debug_load is not None: shared.log.info(f'Autodetect {op}: detect="{guess}" class={getattr(pipeline, "__name__", None)} file="{f}" size={size}MB') t0 = time.time() keys = model_tools.get_safetensor_keys(f) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index e9ac1be92..688af0f34 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -53,7 +53,7 @@ def 
transformers_modeling_utils_load_pretrained_model(*args, **kwargs): def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs): # this file is always 404, prevent making request - if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json': + if (url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14') and args[0] == 'added_tokens.json': return None try: diff --git a/modules/sd_hijack_accelerate.py b/modules/sd_hijack_accelerate.py index 90eac5c4e..f8cf8983f 100644 --- a/modules/sd_hijack_accelerate.py +++ b/modules/sd_hijack_accelerate.py @@ -35,10 +35,10 @@ def hijack_set_module_tensor( with devices.inference_context(): # note: majority of time is spent on .to(old_value.dtype) if tensor_name in module._buffers: # pylint: disable=protected-access - module._buffers[tensor_name] = value.to(device, old_value.dtype, non_blocking=True) # pylint: disable=protected-access + module._buffers[tensor_name] = value.to(device, old_value.dtype) # pylint: disable=protected-access elif value is not None or not devices.same_device(torch.device(device), module._parameters[tensor_name].device): # pylint: disable=protected-access param_cls = type(module._parameters[tensor_name]) # pylint: disable=protected-access - module._parameters[tensor_name] = param_cls(value, requires_grad=old_value.requires_grad).to(device, old_value.dtype, non_blocking=True) # pylint: disable=protected-access + module._parameters[tensor_name] = param_cls(value, requires_grad=old_value.requires_grad).to(device, old_value.dtype) # pylint: disable=protected-access t1 = time.time() tensor_to_timer += (t1 - t0) @@ -63,10 +63,10 @@ def hijack_set_module_tensor_simple( old_value = getattr(module, tensor_name) with devices.inference_context(): if tensor_name in module._buffers: # pylint: disable=protected-access - module._buffers[tensor_name] = value.to(device, non_blocking=True) # pylint: disable=protected-access + module._buffers[tensor_name] = value.to(device) # pylint: disable=protected-access elif value is not None or not devices.same_device(torch.device(device), module._parameters[tensor_name].device): # pylint: disable=protected-access param_cls = type(module._parameters[tensor_name]) # pylint: disable=protected-access - module._parameters[tensor_name] = param_cls(value, requires_grad=old_value.requires_grad).to(device, non_blocking=True) # pylint: disable=protected-access + module._parameters[tensor_name] = param_cls(value, requires_grad=old_value.requires_grad).to(device) # pylint: disable=protected-access t1 = time.time() tensor_to_timer += (t1 - t0) diff --git a/modules/sd_hijack_hypertile.py b/modules/sd_hijack_hypertile.py index dbf977b8d..69c4163dc 100644 --- a/modules/sd_hijack_hypertile.py +++ b/modules/sd_hijack_hypertile.py @@ -112,7 +112,7 @@ def wrapper(*args, **kwargs): out = forward(x, *args[1:], **kwargs) return out if x.ndim == 4: # VAE - # TODO hypertile vae breaks for diffusers when using non-standard sizes + # TODO hypertile: vae breaks when using non-standard sizes if nh * nw > 1: x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) out = forward(x, *args[1:], **kwargs) diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index d8d356071..66040c5aa 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -70,7 +70,7 @@ def hijack_ddpm_edit(): if 
version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) - CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) + CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: (kwargs.update({'act_layer': GELUHijack}) and False) or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 # pylint: disable=unnecessary-lambda-assignment first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs) # pylint: disable=unnecessary-lambda-assignment diff --git a/modules/sd_models.py b/modules/sd_models.py index cf1921a36..087e8ecfc 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -13,10 +13,11 @@ from rich import progress # pylint: disable=redefined-builtin import torch import safetensors.torch +import accelerate from omegaconf import OmegaConf from ldm.util import instantiate_from_config from modules import paths, shared, shared_state, modelloader, devices, script_callbacks, sd_vae, sd_unet, errors, sd_models_config, sd_models_compile, sd_hijack_accelerate, sd_detect -from modules.timer import Timer +from modules.timer import Timer, process as process_timer from modules.memstats import memory_stats from modules.modeldata import model_data from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, checkpoints_list, checkpoint_titles, get_closet_checkpoint_match, model_hash, update_model_hashes, setup_model, write_metadata, read_metadata_from_safetensors # pylint: disable=unused-import @@ -33,6 +34,8 @@ debug_process = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None diffusers_version = int(diffusers.__version__.split('.')[1]) checkpoint_tiles = checkpoint_titles # legacy compatibility +should_offload = ['sc', 'sd3', 'f1', 'hunyuandit', 'auraflow', 'omnigen'] +offload_hook_instance = None class NoWatermark: @@ -279,7 +282,7 @@ def eval_model(model, op=None, sd_model=None): # pylint: disable=unused-argument model.eval() return model sd_model = sd_models_compile.apply_compile_to_model(sd_model, eval_model, ["Model", "VAE", "Text Encoder"], op="eval") - if len(shared.opts.torchao_quantization) > 0: + if len(shared.opts.torchao_quantization) > 0 and shared.opts.torchao_quantization_mode == 'post': sd_model = sd_models_compile.torchao_quantization(sd_model) if shared.opts.opt_channelslast and hasattr(sd_model, 'unet'): @@ -310,6 +313,7 @@ def set_accelerate(sd_model): def set_diffuser_offload(sd_model, op: str = 'model'): + t0 = time.time() if not shared.native: shared.log.warning('Attempting to use offload with backend=original') return @@ -318,13 +322,17 @@ def set_diffuser_offload(sd_model, op: str = 'model'): return if not (hasattr(sd_model, 
"has_accelerate") and sd_model.has_accelerate): sd_model.has_accelerate = False - if hasattr(sd_model, 'maybe_free_model_hooks') and shared.opts.diffusers_offload_mode == "none": - shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode}') - sd_model.maybe_free_model_hooks() - sd_model.has_accelerate = False - if hasattr(sd_model, "enable_model_cpu_offload") and shared.opts.diffusers_offload_mode == "model": + if shared.opts.diffusers_offload_mode == "none": + if shared.sd_model_type in should_offload: + shared.log.warning(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} type={shared.sd_model.__class__.__name__} large model') + else: + shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') + if hasattr(sd_model, 'maybe_free_model_hooks'): + sd_model.maybe_free_model_hooks() + sd_model.has_accelerate = False + if shared.opts.diffusers_offload_mode == "model" and hasattr(sd_model, "enable_model_cpu_offload"): try: - shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode}') + shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner: shared.opts.diffusers_move_base = False shared.opts.diffusers_move_unet = False @@ -337,9 +345,9 @@ def set_diffuser_offload(sd_model, op: str = 'model'): set_accelerate(sd_model) except Exception as e: shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}') - if hasattr(sd_model, "enable_sequential_cpu_offload") and shared.opts.diffusers_offload_mode == "sequential": + if shared.opts.diffusers_offload_mode == "sequential" and hasattr(sd_model, "enable_sequential_cpu_offload"): try: - shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode}') + shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner: shared.opts.diffusers_move_base = False shared.opts.diffusers_move_unet = False @@ -358,74 +366,159 @@ def set_diffuser_offload(sd_model, op: str = 'model'): except Exception as e: shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}') if shared.opts.diffusers_offload_mode == "balanced": - try: - shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode}') - sd_model = apply_balanced_offload(sd_model) - except Exception as e: - shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}') - - -def apply_balanced_offload(sd_model): - from accelerate import infer_auto_device_map, dispatch_model - from accelerate.hooks import add_hook_to_module, remove_hook_from_module, ModelHook + sd_model = apply_balanced_offload(sd_model) + process_timer.add('offload', time.time() - t0) + + +class OffloadHook(accelerate.hooks.ModelHook): + def __init__(self, checkpoint_name): + if shared.opts.diffusers_offload_max_gpu_memory > 1: + shared.opts.diffusers_offload_max_gpu_memory = 0.75 + if shared.opts.diffusers_offload_max_cpu_memory > 1: + shared.opts.diffusers_offload_max_cpu_memory = 0.75 + self.checkpoint_name = checkpoint_name + self.min_watermark = shared.opts.diffusers_offload_min_gpu_memory + self.max_watermark = shared.opts.diffusers_offload_max_gpu_memory + self.cpu_watermark = 
shared.opts.diffusers_offload_max_cpu_memory + self.gpu = int(shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory * 1024*1024*1024) + self.cpu = int(shared.cpu_memory * shared.opts.diffusers_offload_max_cpu_memory * 1024*1024*1024) + self.offload_map = {} + self.param_map = {} + gpu = f'{shared.gpu_memory * shared.opts.diffusers_offload_min_gpu_memory:.3f}-{shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory}:{shared.gpu_memory}' + shared.log.info(f'Offload: type=balanced op=init watermark={self.min_watermark}-{self.max_watermark} gpu={gpu} cpu={shared.cpu_memory:.3f} limit={shared.opts.cuda_mem_fraction:.2f}') + self.validate() + super().__init__() + + def validate(self): + if shared.opts.diffusers_offload_mode != 'balanced': + return + if shared.opts.diffusers_offload_min_gpu_memory < 0 or shared.opts.diffusers_offload_min_gpu_memory > 1: + shared.opts.diffusers_offload_min_gpu_memory = 0.25 + shared.log.warning(f'Offload: type=balanced op=validate: watermark low={shared.opts.diffusers_offload_min_gpu_memory} invalid value') + if shared.opts.diffusers_offload_max_gpu_memory < 0.1 or shared.opts.diffusers_offload_max_gpu_memory > 1: + shared.opts.diffusers_offload_max_gpu_memory = 0.75 + shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} invalid value') + if shared.opts.diffusers_offload_min_gpu_memory > shared.opts.diffusers_offload_max_gpu_memory: + shared.opts.diffusers_offload_min_gpu_memory = shared.opts.diffusers_offload_max_gpu_memory + shared.log.warning(f'Offload: type=balanced op=validate: watermark low={shared.opts.diffusers_offload_min_gpu_memory} reset') + if shared.opts.diffusers_offload_max_gpu_memory * shared.gpu_memory < 4: + shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} low memory') + + def model_size(self): + return sum(self.offload_map.values()) + + def init_hook(self, module): + return module + + def pre_forward(self, module, *args, **kwargs): + if devices.normalize_device(module.device) != devices.normalize_device(devices.device): + device_index = torch.device(devices.device).index + if device_index is None: + device_index = 0 + max_memory = { device_index: self.gpu, "cpu": self.cpu } + device_map = getattr(module, "balanced_offload_device_map", None) + if device_map is None or max_memory != getattr(module, "balanced_offload_max_memory", None): + device_map = accelerate.infer_auto_device_map(module, max_memory=max_memory) + offload_dir = getattr(module, "offload_dir", os.path.join(shared.opts.accelerate_offload_path, module.__class__.__name__)) + module = accelerate.dispatch_model(module, device_map=device_map, offload_dir=offload_dir) + module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access + module.balanced_offload_device_map = device_map + module.balanced_offload_max_memory = max_memory + return args, kwargs + + def post_forward(self, module, output): + return output + + def detach_hook(self, module): + return module + + +def apply_balanced_offload(sd_model, exclude=[]): + global offload_hook_instance # pylint: disable=global-statement + if shared.opts.diffusers_offload_mode != "balanced": + return sd_model + t0 = time.time() excluded = ['OmniGenPipeline'] if sd_model.__class__.__name__ in excluded: return sd_model - - class dispatch_from_cpu_hook(ModelHook): - def init_hook(self, module): - return module - - def pre_forward(self, module, *args, **kwargs): 
- if devices.normalize_device(module.device) != devices.normalize_device(devices.device): - device_index = torch.device(devices.device).index - if device_index is None: - device_index = 0 - max_memory = { - device_index: f"{shared.opts.diffusers_offload_max_gpu_memory}GiB", - "cpu": f"{shared.opts.diffusers_offload_max_cpu_memory}GiB", - } - device_map = infer_auto_device_map(module, max_memory=max_memory) - module = remove_hook_from_module(module, recurse=True) - offload_dir = getattr(module, "offload_dir", os.path.join(shared.opts.accelerate_offload_path, module.__class__.__name__)) - module = dispatch_model(module, device_map=device_map, offload_dir=offload_dir) - module = add_hook_to_module(module, dispatch_from_cpu_hook(), append=True) - module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access - return args, kwargs - - def post_forward(self, module, output): - return output - - def detach_hook(self, module): - return module + cached = True + checkpoint_name = sd_model.sd_checkpoint_info.name if getattr(sd_model, "sd_checkpoint_info", None) is not None else None + if checkpoint_name is None: + checkpoint_name = sd_model.__class__.__name__ + if offload_hook_instance is None or offload_hook_instance.min_watermark != shared.opts.diffusers_offload_min_gpu_memory or offload_hook_instance.max_watermark != shared.opts.diffusers_offload_max_gpu_memory or checkpoint_name != offload_hook_instance.checkpoint_name: + cached = False + offload_hook_instance = OffloadHook(checkpoint_name) + + def get_pipe_modules(pipe): + if hasattr(pipe, "_internal_dict"): + modules_names = pipe._internal_dict.keys() # pylint: disable=protected-access + else: + modules_names = get_signature(pipe).keys() + modules_names = [m for m in modules_names if m not in exclude and not m.startswith('_')] + modules = {} + for module_name in modules_names: + module_size = offload_hook_instance.offload_map.get(module_name, None) + if module_size is None: + module = getattr(pipe, module_name, None) + if not isinstance(module, torch.nn.Module): + continue + try: + module_size = sum(p.numel() * p.element_size() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024 + param_num = sum(p.numel() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024 + except Exception as e: + shared.log.error(f'Offload: type=balanced op=calc module={module_name} {e}') + module_size = 0 + offload_hook_instance.offload_map[module_name] = module_size + offload_hook_instance.param_map[module_name] = param_num + modules[module_name] = module_size + modules = sorted(modules.items(), key=lambda x: x[1], reverse=True) + return modules def apply_balanced_offload_to_module(pipe): + used_gpu, used_ram = devices.torch_gc(fast=True) if hasattr(pipe, "pipe"): apply_balanced_offload_to_module(pipe.pipe) if hasattr(pipe, "_internal_dict"): keys = pipe._internal_dict.keys() # pylint: disable=protected-access else: - keys = get_signature(shared.sd_model).keys() - for module_name in keys: # pylint: disable=protected-access + keys = get_signature(pipe).keys() + keys = [k for k in keys if k not in exclude and not k.startswith('_')] + for module_name, module_size in get_pipe_modules(pipe): # pylint: disable=protected-access module = getattr(pipe, module_name, None) - if isinstance(module, torch.nn.Module): - checkpoint_name = pipe.sd_checkpoint_info.name if getattr(pipe, "sd_checkpoint_info", None) is not None else None - if checkpoint_name is None: - checkpoint_name = pipe.__class__.__name__ - offload_dir = 
os.path.join(shared.opts.accelerate_offload_path, checkpoint_name, module_name) - module = remove_hook_from_module(module, recurse=True) - try: - module = module.to("cpu") - module.offload_dir = offload_dir - network_layer_name = getattr(module, "network_layer_name", None) - module = add_hook_to_module(module, dispatch_from_cpu_hook(), append=True) - module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access - if network_layer_name: - module.network_layer_name = network_layer_name - except Exception as e: - if 'bitsandbytes' not in str(e): - shared.log.error(f'Balanced offload: module={module_name} {e}') - devices.torch_gc(fast=True) + if module is None: + continue + network_layer_name = getattr(module, "network_layer_name", None) + device_map = getattr(module, "balanced_offload_device_map", None) + max_memory = getattr(module, "balanced_offload_max_memory", None) + module = accelerate.hooks.remove_hook_from_module(module, recurse=True) + perc_gpu = used_gpu / shared.gpu_memory + try: + prev_gpu = used_gpu + do_offload = (perc_gpu > shared.opts.diffusers_offload_min_gpu_memory) and (module.device != devices.cpu) + if do_offload: + module = module.to(devices.cpu, non_blocking=True) + used_gpu -= module_size + if not cached: + shared.log.debug(f'Offload: type=balanced module={module_name} cls={module.__class__.__name__} dtype={module.dtype} quant={getattr(module, "quantization_method", None)} params={offload_hook_instance.param_map[module_name]:.3f} size={offload_hook_instance.offload_map[module_name]:.3f}') + debug_move(f'Offload: type=balanced op={"move" if do_offload else "skip"} gpu={prev_gpu:.3f}:{used_gpu:.3f} perc={perc_gpu:.2f} ram={used_ram:.3f} current={module.device} dtype={module.dtype} quant={getattr(module, "quantization_method", None)} module={module.__class__.__name__} size={module_size:.3f}') + except Exception as e: + if 'out of memory' in str(e): + devices.torch_gc(fast=True, force=True, reason='oom') + elif 'bitsandbytes' in str(e): + pass + else: + shared.log.error(f'Offload: type=balanced op=apply module={module_name} {e}') + if os.environ.get('SD_MOVE_DEBUG', None): + errors.display(e, f'Offload: type=balanced op=apply module={module_name}') + module.offload_dir = os.path.join(shared.opts.accelerate_offload_path, checkpoint_name, module_name) + module = accelerate.hooks.add_hook_to_module(module, offload_hook_instance, append=True) + module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access + if network_layer_name: + module.network_layer_name = network_layer_name + if device_map and max_memory: + module.balanced_offload_device_map = device_map + module.balanced_offload_max_memory = max_memory + devices.torch_gc(fast=True, force=True, reason='offload') apply_balanced_offload_to_module(sd_model) if hasattr(sd_model, "pipe"): @@ -435,6 +528,12 @@ def apply_balanced_offload_to_module(pipe): if hasattr(sd_model, "decoder_pipe"): apply_balanced_offload_to_module(sd_model.decoder_pipe) set_accelerate(sd_model) + t = time.time() - t0 + process_timer.add('offload', t) + fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access + debug_move(f'Apply offload: time={t:.2f} type=balanced fn={fn}') + if not cached: + shared.log.info(f'Offload: type=balanced op=apply class={sd_model.__class__.__name__} modules={len(offload_hook_instance.offload_map)} size={offload_hook_instance.model_size():.3f}') return sd_model @@ -479,7 +578,7 @@ def 
move_model(model, device=None, force=False): shared.log.error(f'Model move execution device: device={device} {e}') if getattr(model, 'has_accelerate', False) and not force: return - if hasattr(model, "device") and devices.normalize_device(model.device) == devices.normalize_device(device): + if hasattr(model, "device") and devices.normalize_device(model.device) == devices.normalize_device(device) and not force: return try: t0 = time.time() @@ -512,7 +611,10 @@ def move_model(model, device=None, force=False): except Exception as e1: t1 = time.time() shared.log.error(f'Model move: device={device} {e1}') - if os.environ.get('SD_MOVE_DEBUG', None) or (t1-t0) > 0.1: + if 'move' not in process_timer.records: + process_timer.records['move'] = 0 + process_timer.records['move'] += t1 - t0 + if os.environ.get('SD_MOVE_DEBUG', None) or (t1-t0) > 2: shared.log.debug(f'Model move: device={device} class={model.__class__.__name__} accelerate={getattr(model, "has_accelerate", False)} fn={fn} time={t1-t0:.2f}') # pylint: disable=protected-access devices.torch_gc() @@ -612,6 +714,9 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op=' elif model_type in ['PixArt-Sigma']: # forced pipeline from modules.model_pixart import load_pixart sd_model = load_pixart(checkpoint_info, diffusers_load_config) + elif model_type in ['Sana']: # forced pipeline + from modules.model_sana import load_sana + sd_model = load_sana(checkpoint_info, diffusers_load_config) elif model_type in ['Lumina-Next']: # forced pipeline from modules.model_lumina import load_lumina sd_model = load_lumina(checkpoint_info, diffusers_load_config) @@ -771,7 +876,7 @@ def load_diffuser_file(model_type, pipeline, checkpoint_info, diffusers_load_con return sd_model -def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=None, op='model'): # pylint: disable=unused-argument +def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=None, op='model', revision=None): # pylint: disable=unused-argument if timer is None: timer = Timer() logging.getLogger("diffusers").setLevel(logging.ERROR) @@ -784,6 +889,8 @@ def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=No "requires_safety_checker": False, # sd15 specific but we cant know ahead of time # "use_safetensors": True, } + if revision is not None: + diffusers_load_config['revision'] = revision if shared.opts.diffusers_model_load_variant != 'default': diffusers_load_config['variant'] = shared.opts.diffusers_model_load_variant if shared.opts.diffusers_pipeline == 'Custom Diffusers Pipeline' and len(shared.opts.custom_diffusers_pipeline) > 0: @@ -925,7 +1032,8 @@ def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=No shared.log.error(f"Load {op}: {e}") errors.display(e, "Model") - devices.torch_gc(force=True) + if shared.opts.diffusers_offload_mode != 'balanced': + devices.torch_gc(force=True) if sd_model is not None: script_callbacks.model_loaded_callback(sd_model) @@ -961,6 +1069,11 @@ def get_signature(cls): return signature.parameters +def get_call(cls): + signature = inspect.signature(cls.__call__, follow_wrapped=True, eval_str=True) + return signature.parameters + + def switch_pipe(cls: diffusers.DiffusionPipeline, pipeline: diffusers.DiffusionPipeline = None, force = False, args = {}): """ args: @@ -1077,6 +1190,9 @@ def set_diffuser_pipe(pipe, new_pipe_type): 'OmniGenPipeline', 'StableDiffusion3ControlNetPipeline', 'InstantIRPipeline', + 'FluxFillPipeline', + 'FluxControlPipeline', 
+ 'StableVideoDiffusionPipeline', ] n = getattr(pipe.__class__, '__name__', '') @@ -1345,7 +1461,7 @@ def reload_text_encoder(initial=False): set_t5(pipe=shared.sd_model, module='text_encoder_3', t5=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir) -def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model', force=False): +def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model', force=False, revision=None): load_dict = shared.opts.sd_model_dict != model_data.sd_dict from modules import lowvram, sd_hijack checkpoint_info = info or select_checkpoint(op=op) # are we selecting model or dictionary @@ -1380,7 +1496,7 @@ def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model', unload_model_weights(op=op) sd_model = None timer = Timer() - # TODO implement caching after diffusers implement state_dict loading + # TODO model loader: implement model in-memory caching state_dict = get_checkpoint_state_dict(checkpoint_info, timer) if not shared.native else None checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) timer.record("config") @@ -1390,7 +1506,7 @@ def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model', load_model(checkpoint_info, already_loaded_state_dict=state_dict, timer=timer, op=op) model_data.sd_dict = shared.opts.sd_model_dict else: - load_diffuser(checkpoint_info, already_loaded_state_dict=state_dict, timer=timer, op=op) + load_diffuser(checkpoint_info, already_loaded_state_dict=state_dict, timer=timer, op=op, revision=revision) if load_dict and next_checkpoint_info is not None: model_data.sd_dict = shared.opts.sd_model_dict shared.opts.data["sd_model_checkpoint"] = next_checkpoint_info.title @@ -1441,10 +1557,17 @@ def disable_offload(sd_model): from accelerate.hooks import remove_hook_from_module if not getattr(sd_model, 'has_accelerate', False): return - if hasattr(sd_model, 'components'): - for _name, model in sd_model.components.items(): - if isinstance(model, torch.nn.Module): - remove_hook_from_module(model, recurse=True) + if hasattr(sd_model, "_internal_dict"): + keys = sd_model._internal_dict.keys() # pylint: disable=protected-access + else: + keys = get_signature(sd_model).keys() + for module_name in keys: # pylint: disable=protected-access + module = getattr(sd_model, module_name, None) + if isinstance(module, torch.nn.Module): + network_layer_name = getattr(module, "network_layer_name", None) + module = remove_hook_from_module(module, recurse=True) + if network_layer_name: + module.network_layer_name = network_layer_name sd_model.has_accelerate = False diff --git a/modules/sd_models_compile.py b/modules/sd_models_compile.py index 91ed84ded..20a7d7de2 100644 --- a/modules/sd_models_compile.py +++ b/modules/sd_models_compile.py @@ -47,7 +47,7 @@ def apply_compile_to_model(sd_model, function, options, op=None): sd_model.prior_pipe.prior.clip_txt_pooled_mapper = backup_clip_txt_pooled_mapper if "Text Encoder" in options: if hasattr(sd_model, 'text_encoder') and hasattr(sd_model.text_encoder, 'config'): - if hasattr(sd_model, 'decoder_pipe') and hasattr(sd_model.decoder_pipe, 'text_encoder'): + if hasattr(sd_model, 'decoder_pipe') and hasattr(sd_model.decoder_pipe, 'text_encoder') and hasattr(sd_model.decoder_pipe.text_encoder, 'config'): sd_model.decoder_pipe.text_encoder = function(sd_model.decoder_pipe.text_encoder, op="decoder_pipe.text_encoder", sd_model=sd_model) else: if op == "nncf" and sd_model.text_encoder.__class__.__name__ 
in {"T5EncoderModel", "UMT5EncoderModel"}: @@ -76,7 +76,7 @@ def apply_compile_to_model(sd_model, function, options, op=None): dtype=torch.float32 if devices.dtype != torch.bfloat16 else torch.bfloat16 ) sd_model.text_encoder_3 = function(sd_model.text_encoder_3, op="text_encoder_3", sd_model=sd_model) - if hasattr(sd_model, 'prior_pipe') and hasattr(sd_model.prior_pipe, 'text_encoder'): + if hasattr(sd_model, 'prior_pipe') and hasattr(sd_model.prior_pipe, 'text_encoder') and hasattr(sd_model.prior_pipe.text_encoder, 'config'): sd_model.prior_pipe.text_encoder = function(sd_model.prior_pipe.text_encoder, op="prior_pipe.text_encoder", sd_model=sd_model) if "VAE" in options: if hasattr(sd_model, 'vae') and hasattr(sd_model.vae, 'decode'): @@ -505,51 +505,26 @@ def compile_diffusers(sd_model): def torchao_quantization(sd_model): try: - install('torchao', quiet=True) + install('torchao==0.7.0', quiet=True) from torchao import quantization as q except Exception as e: shared.log.error(f"Quantization: type=TorchAO quantization not supported: {e}") return sd_model - if shared.opts.torchao_quantization_type == "int8+act": - fn = q.int8_dynamic_activation_int8_weight - elif shared.opts.torchao_quantization_type == "int8": - fn = q.int8_weight_only - elif shared.opts.torchao_quantization_type == "int4": - fn = q.int4_weight_only - elif shared.opts.torchao_quantization_type == "fp8+act": - fn = q.float8_dynamic_activation_float8_weight - elif shared.opts.torchao_quantization_type == "fp8": - fn = q.float8_weight_only - elif shared.opts.torchao_quantization_type == "fpx": - fn = q.fpx_weight_only - else: + + fn = getattr(q, shared.opts.torchao_quantization_type, None) + if fn is None: shared.log.error(f"Quantization: type=TorchAO type={shared.opts.torchao_quantization_type} not supported") return sd_model + def torchao_model(model, op=None, sd_model=None): # pylint: disable=unused-argument + q.quantize_(model, fn(), device=devices.device) + return model + shared.log.info(f"Quantization: type=TorchAO pipe={sd_model.__class__.__name__} quant={shared.opts.torchao_quantization_type} fn={fn} targets={shared.opts.torchao_quantization}") try: t0 = time.time() - modules = [] - if hasattr(sd_model, 'unet') and 'Model' in shared.opts.torchao_quantization: - modules.append('unet') - q.quantize_(sd_model.unet, fn(), device=devices.device) - if hasattr(sd_model, 'transformer') and 'Model' in shared.opts.torchao_quantization: - modules.append('transformer') - q.quantize_(sd_model.transformer, fn(), device=devices.device) - # sd_model.transformer = q.autoquant(sd_model.transformer, error_on_unseen=False) - if hasattr(sd_model, 'vae') and 'VAE' in shared.opts.torchao_quantization: - modules.append('vae') - q.quantize_(sd_model.vae, fn(), device=devices.device) - if hasattr(sd_model, 'text_encoder') and 'Text Encoder' in shared.opts.torchao_quantization: - modules.append('te1') - q.quantize_(sd_model.text_encoder, fn(), device=devices.device) - if hasattr(sd_model, 'text_encoder_2') and 'Text Encoder' in shared.opts.torchao_quantization: - modules.append('te2') - q.quantize_(sd_model.text_encoder_2, fn(), device=devices.device) - if hasattr(sd_model, 'text_encoder_3') and 'Text Encoder' in shared.opts.torchao_quantization: - modules.append('te3') - q.quantize_(sd_model.text_encoder_3, fn(), device=devices.device) + apply_compile_to_model(sd_model, torchao_model, shared.opts.torchao_quantization, op="torchao") t1 = time.time() - shared.log.info(f"Quantization: type=TorchAO modules={modules} time={t1-t0:.2f}") + 
shared.log.info(f"Quantization: type=TorchAO time={t1-t0:.2f}") except Exception as e: shared.log.error(f"Quantization: type=TorchAO {e}") setup_logging() # torchao uses dynamo which messes with logging so reset is needed diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index e560744dd..dc58a2419 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -3,6 +3,7 @@ from modules import shared from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # pylint: disable=unused-import + debug = shared.log.trace if os.environ.get('SD_SAMPLER_DEBUG', None) is not None else lambda *args, **kwargs: None debug('Trace: SAMPLER') all_samplers = [] @@ -47,6 +48,8 @@ def visible_sampler_names(): def create_sampler(name, model): + if name is None or name == 'None': + return model.scheduler try: current = model.scheduler.__class__.__name__ except Exception: @@ -73,15 +76,15 @@ def create_sampler(name, model): shared.log.debug(f'Sampler: sampler="{name}" config={config.options}') return sampler elif shared.native: - FlowModels = ['Flux', 'StableDiffusion3', 'Lumina', 'AuraFlow'] + FlowModels = ['Flux', 'StableDiffusion3', 'Lumina', 'AuraFlow', 'Sana'] if 'KDiffusion' in model.__class__.__name__: return None - if any(x in model.__class__.__name__ for x in FlowModels) and 'FlowMatch' not in name: - shared.log.warning(f'Sampler: default={current} target="{name}" class={model.__class__.__name__} linear scheduler unsupported') - return None if not any(x in model.__class__.__name__ for x in FlowModels) and 'FlowMatch' in name: shared.log.warning(f'Sampler: default={current} target="{name}" class={model.__class__.__name__} flow-match scheduler unsupported') return None + # if any(x in model.__class__.__name__ for x in FlowModels) and 'FlowMatch' not in name: + # shared.log.warning(f'Sampler: default={current} target="{name}" class={model.__class__.__name__} linear scheduler unsupported') + # return None sampler = config.constructor(model) if sampler is None: sampler = config.constructor(model) @@ -91,7 +94,8 @@ def create_sampler(name, model): if hasattr(model, "prior_pipe") and hasattr(model.prior_pipe, "scheduler"): model.prior_pipe.scheduler = sampler.sampler model.prior_pipe.scheduler.config.clip_sample = False - shared.log.debug(f'Sampler: sampler="{sampler.name}" class="{model.scheduler.__class__.__name__} config={sampler.config}') + clean_config = {k: v for k, v in sampler.config.items() if v is not None and v is not False} + shared.log.debug(f'Sampler: sampler="{sampler.name}" class="{model.scheduler.__class__.__name__} config={clean_config}') return sampler.sampler else: return None diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index a487fe9b7..a90ceec27 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -1,13 +1,15 @@ +import time import threading from collections import namedtuple import torch import torchvision.transforms as T from PIL import Image -from modules import shared, devices, processing, images, sd_vae_approx, sd_vae_taesd, sd_vae_stablecascade, sd_samplers +from modules import shared, devices, processing, images, sd_vae_approx, sd_vae_taesd, sd_vae_stablecascade, sd_samplers, timer SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) approximation_indexes = { "Simple": 0, "Approximate": 1, "TAESD": 2, "Full VAE": 3 } +flow_models = ['f1', 'sd3', 'lumina', 'auraflow', 'sana'] warned = False queue_lock = threading.Lock() @@ -33,13 +35,12 @@ def 
setup_img2img_steps(p, steps=None): def single_sample_to_image(sample, approximation=None): with queue_lock: - sd_cascade = False + t0 = time.time() if approximation is None: approximation = approximation_indexes.get(shared.opts.show_progress_type, None) if approximation is None: warn_once('Unknown decode type') approximation = 0 - # normal sample is [4,64,64] try: if sample.dtype == torch.bfloat16 and (approximation == 0 or approximation == 1): sample = sample.to(torch.float16) @@ -48,22 +49,15 @@ def single_sample_to_image(sample, approximation=None): if len(sample.shape) > 4: # likely unknown video latent (e.g. svd) return Image.new(mode="RGB", size=(512, 512)) - if len(sample) == 16: # sd_cascade - sd_cascade = True if len(sample.shape) == 4 and sample.shape[0]: # likely animatediff latent sample = sample.permute(1, 0, 2, 3)[0] - if shared.native: # [-x,x] to [-5,5] - sample_max = torch.max(sample) - if sample_max > 5: - sample = sample * (5 / sample_max) - sample_min = torch.min(sample) - if sample_min < -5: - sample = sample * (5 / abs(sample_min)) - if approximation == 2: # TAESD + if shared.opts.live_preview_downscale and (sample.shape[-1] > 128 or sample.shape[-2] > 128): + scale = 128 / max(sample.shape[-1], sample.shape[-2]) + sample = torch.nn.functional.interpolate(sample.unsqueeze(0), scale_factor=[scale, scale], mode='bilinear', align_corners=False)[0] x_sample = sd_vae_taesd.decode(sample) x_sample = (1.0 + x_sample) / 2.0 # preview requires smaller range - elif sd_cascade and approximation != 3: + elif shared.sd_model_type == 'sc' and approximation != 3: x_sample = sd_vae_stablecascade.decode(sample) elif approximation == 0: # Simple x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5 @@ -84,6 +78,8 @@ def single_sample_to_image(sample, approximation=None): except Exception as e: warn_once(f'Preview: {e}') image = Image.new(mode="RGB", size=(512, 512)) + t1 = time.time() + timer.process.add('preview', t1 - t0) return image diff --git a/modules/sd_samplers_diffusers.py b/modules/sd_samplers_diffusers.py index 60c75b64e..6c05b2045 100644 --- a/modules/sd_samplers_diffusers.py +++ b/modules/sd_samplers_diffusers.py @@ -4,13 +4,12 @@ import inspect import diffusers from modules import shared, errors -from modules import sd_samplers_common +from modules.sd_samplers_common import SamplerData, flow_models debug = shared.log.trace if os.environ.get('SD_SAMPLER_DEBUG', None) is not None else lambda *args, **kwargs: None debug('Trace: SAMPLER') - try: from diffusers import ( CMStochasticIterativeScheduler, @@ -52,6 +51,8 @@ from modules.schedulers.scheduler_dc import DCSolverMultistepScheduler # pylint: disable=ungrouped-imports from modules.schedulers.scheduler_vdm import VDMScheduler # pylint: disable=ungrouped-imports from modules.schedulers.scheduler_dpm_flowmatch import FlowMatchDPMSolverMultistepScheduler # pylint: disable=ungrouped-imports + from modules.schedulers.scheduler_bdia import BDIA_DDIMScheduler # pylint: disable=ungrouped-imports + from modules.schedulers.scheduler_ufogen import UFOGenScheduler # pylint: disable=ungrouped-imports except Exception as e: shared.log.error(f'Diffusers import error: version={diffusers.__version__} error: {e}') if os.environ.get('SD_SAMPLER_DEBUG', None) is not None: @@ -62,41 +63,43 @@ # prediction_type is ideally set in model as well, but it maybe needed that we do auto-detect of model type in the future 'All': { 'num_train_timesteps': 1000, 'beta_start': 0.0001, 'beta_end': 0.02, 'beta_schedule': 'linear', 
'prediction_type': 'epsilon' }, - 'UniPC': { 'predict_x0': True, 'sample_max_value': 1.0, 'solver_order': 2, 'solver_type': 'bh2', 'thresholding': False, 'use_beta_sigmas': False, 'use_exponential_sigmas': False, 'use_karras_sigmas': False, 'lower_order_final': True, 'timestep_spacing': 'linspace', 'final_sigmas_type': 'zero', 'rescale_betas_zero_snr': False }, + 'UniPC': { 'predict_x0': True, 'sample_max_value': 1.0, 'solver_order': 2, 'solver_type': 'bh2', 'thresholding': False, 'use_beta_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_karras_sigmas': False, 'lower_order_final': True, 'timestep_spacing': 'linspace', 'final_sigmas_type': 'zero', 'rescale_betas_zero_snr': False }, 'DDIM': { 'clip_sample': False, 'set_alpha_to_one': True, 'steps_offset': 0, 'clip_sample_range': 1.0, 'sample_max_value': 1.0, 'timestep_spacing': 'leading', 'rescale_betas_zero_snr': False, 'thresholding': False }, 'Euler': { 'steps_offset': 0, 'interpolation_type': "linear", 'rescale_betas_zero_snr': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'use_beta_sigmas': False, 'use_exponential_sigmas': False, 'use_karras_sigmas': False }, 'Euler a': { 'steps_offset': 0, 'rescale_betas_zero_snr': False, 'timestep_spacing': 'linspace' }, 'Euler SGM': { 'steps_offset': 0, 'interpolation_type': "linear", 'rescale_betas_zero_snr': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'trailing', 'use_beta_sigmas': False, 'use_exponential_sigmas': False, 'use_karras_sigmas': False, 'prediction_type': "sample" }, 'Euler EDM': { 'sigma_schedule': "karras" }, - 'Euler FlowMatch': { 'timestep_spacing': "linspace", 'shift': 1, 'use_dynamic_shifting': False }, + 'Euler FlowMatch': { 'timestep_spacing': "linspace", 'shift': 1, 'use_dynamic_shifting': False, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False }, - 'DPM++': { 'solver_order': 2, 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'final_sigmas_type': 'sigma_min' }, - 'DPM++ 1S': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 1 }, - 'DPM++ 2M': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 2 }, - 'DPM++ 3M': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 3 }, - 'DPM++ 2M SDE': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "sde-dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 
'solver_order': 2 }, + 'DPM++': { 'solver_order': 2, 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'final_sigmas_type': 'sigma_min' }, + 'DPM++ 1S': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 1 }, + 'DPM++ 2M': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 2 }, + 'DPM++ 3M': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 3 }, + 'DPM++ 2M SDE': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "sde-dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 2 }, 'DPM++ 2M EDM': { 'solver_order': 2, 'solver_type': 'midpoint', 'final_sigmas_type': 'zero', 'algorithm_type': 'dpmsolver++' }, 'DPM++ Cosine': { 'solver_order': 2, 'sigma_schedule': "exponential", 'prediction_type': "v-prediction" }, 'DPM SDE': { 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'noise_sampler_seed': None, 'timestep_spacing': 'linspace', 'steps_offset': 0, }, - 'DPM2 FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver2', 'use_noise_sampler': True }, - 'DPM2a FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver2A', 'use_noise_sampler': True }, - 'DPM2++ 2M FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++2M', 'use_noise_sampler': True }, - 'DPM2++ 2S FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++2S', 'use_noise_sampler': True }, - 'DPM2++ SDE FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++sde', 'use_noise_sampler': True }, - 'DPM2++ 2M SDE FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++2Msde', 'use_noise_sampler': True }, - 'DPM2++ 3M SDE FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 3, 
'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++3Msde', 'use_noise_sampler': True }, + 'DPM2 FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver2', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012 }, + 'DPM2a FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver2A', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012 }, + 'DPM2++ 2M FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++2M', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012 }, + 'DPM2++ 2S FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++2S', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012 }, + 'DPM2++ SDE FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++sde', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012 }, + 'DPM2++ 2M SDE FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++2Msde', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012 }, + 'DPM2++ 3M SDE FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 3, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++3Msde', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012 }, 'Heun': { 'use_beta_sigmas': False, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'timestep_spacing': 'linspace' }, 'Heun FlowMatch': { 'timestep_spacing': "linspace", 'shift': 1 }, - 'DEIS': { 'solver_order': 2, 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "deis", 'solver_type': "logrho", 'lower_order_final': True, 'timestep_spacing': 'linspace', 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False }, - 'SA Solver': {'predictor_order': 2, 'corrector_order': 2, 'thresholding': False, 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'timestep_spacing': 'linspace'}, + 'DEIS': { 'solver_order': 2, 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "deis", 'solver_type': "logrho", 'lower_order_final': True, 'timestep_spacing': 'linspace', 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False }, + 'SA Solver': {'predictor_order': 2, 'corrector_order': 2, 'thresholding': False, 'lower_order_final': True, 'use_karras_sigmas': False, 'use_flow_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'timestep_spacing': 'linspace'}, 'DC Solver': { 'beta_start': 0.0001, 'beta_end': 0.02, 'solver_order': 2, 'prediction_type': "epsilon", 'thresholding': False, 'solver_type': 'bh2', 'lower_order_final': True, 'dc_order': 2, 'disable_corrector': [0] }, 'VDM Solver': { 'clip_sample_range': 2.0, }, 'LCM': { 'beta_start': 0.00085, 'beta_end': 0.012, 'beta_schedule': "scaled_linear", 'set_alpha_to_one': True, 'rescale_betas_zero_snr': False, 'thresholding': False, 'timestep_spacing': 
'linspace' }, 'TCD': { 'set_alpha_to_one': True, 'rescale_betas_zero_snr': False, 'beta_schedule': 'scaled_linear' }, + 'UFOGen': {}, + 'BDIA DDIM': { 'clip_sample': False, 'set_alpha_to_one': True, 'steps_offset': 0, 'clip_sample_range': 1.0, 'sample_max_value': 1.0, 'timestep_spacing': 'leading', 'rescale_betas_zero_snr': False, 'thresholding': False, 'gamma': 1.0 }, 'PNDM': { 'skip_prk_steps': False, 'set_alpha_to_one': False, 'steps_offset': 0, 'timestep_spacing': 'linspace' }, 'IPNDM': { }, @@ -108,53 +111,55 @@ } samplers_data_diffusers = [ - sd_samplers_common.SamplerData('Default', None, [], {}), + SamplerData('Default', None, [], {}), - sd_samplers_common.SamplerData('UniPC', lambda model: DiffusionSampler('UniPC', UniPCMultistepScheduler, model), [], {}), - sd_samplers_common.SamplerData('DDIM', lambda model: DiffusionSampler('DDIM', DDIMScheduler, model), [], {}), - sd_samplers_common.SamplerData('Euler', lambda model: DiffusionSampler('Euler', EulerDiscreteScheduler, model), [], {}), - sd_samplers_common.SamplerData('Euler a', lambda model: DiffusionSampler('Euler a', EulerAncestralDiscreteScheduler, model), [], {}), - sd_samplers_common.SamplerData('Euler SGM', lambda model: DiffusionSampler('Euler SGM', EulerDiscreteScheduler, model), [], {}), - sd_samplers_common.SamplerData('Euler EDM', lambda model: DiffusionSampler('Euler EDM', EDMEulerScheduler, model), [], {}), - sd_samplers_common.SamplerData('Euler FlowMatch', lambda model: DiffusionSampler('Euler FlowMatch', FlowMatchEulerDiscreteScheduler, model), [], {}), + SamplerData('UniPC', lambda model: DiffusionSampler('UniPC', UniPCMultistepScheduler, model), [], {}), + SamplerData('DDIM', lambda model: DiffusionSampler('DDIM', DDIMScheduler, model), [], {}), + SamplerData('Euler', lambda model: DiffusionSampler('Euler', EulerDiscreteScheduler, model), [], {}), + SamplerData('Euler a', lambda model: DiffusionSampler('Euler a', EulerAncestralDiscreteScheduler, model), [], {}), + SamplerData('Euler SGM', lambda model: DiffusionSampler('Euler SGM', EulerDiscreteScheduler, model), [], {}), + SamplerData('Euler EDM', lambda model: DiffusionSampler('Euler EDM', EDMEulerScheduler, model), [], {}), + SamplerData('Euler FlowMatch', lambda model: DiffusionSampler('Euler FlowMatch', FlowMatchEulerDiscreteScheduler, model), [], {}), - sd_samplers_common.SamplerData('DPM++', lambda model: DiffusionSampler('DPM++', DPMSolverSinglestepScheduler, model), [], {}), - sd_samplers_common.SamplerData('DPM++ 1S', lambda model: DiffusionSampler('DPM++ 1S', DPMSolverMultistepScheduler, model), [], {}), - sd_samplers_common.SamplerData('DPM++ 2M', lambda model: DiffusionSampler('DPM++ 2M', DPMSolverMultistepScheduler, model), [], {}), - sd_samplers_common.SamplerData('DPM++ 3M', lambda model: DiffusionSampler('DPM++ 3M', DPMSolverMultistepScheduler, model), [], {}), - sd_samplers_common.SamplerData('DPM++ 2M SDE', lambda model: DiffusionSampler('DPM++ 2M SDE', DPMSolverMultistepScheduler, model), [], {}), - sd_samplers_common.SamplerData('DPM++ 2M EDM', lambda model: DiffusionSampler('DPM++ 2M EDM', EDMDPMSolverMultistepScheduler, model), [], {}), - sd_samplers_common.SamplerData('DPM++ Cosine', lambda model: DiffusionSampler('DPM++ 2M EDM', CosineDPMSolverMultistepScheduler, model), [], {}), - sd_samplers_common.SamplerData('DPM SDE', lambda model: DiffusionSampler('DPM SDE', DPMSolverSDEScheduler, model), [], {}), + SamplerData('DPM++', lambda model: DiffusionSampler('DPM++', DPMSolverSinglestepScheduler, model), [], {}), + SamplerData('DPM++ 1S', 
lambda model: DiffusionSampler('DPM++ 1S', DPMSolverMultistepScheduler, model), [], {}), + SamplerData('DPM++ 2M', lambda model: DiffusionSampler('DPM++ 2M', DPMSolverMultistepScheduler, model), [], {}), + SamplerData('DPM++ 3M', lambda model: DiffusionSampler('DPM++ 3M', DPMSolverMultistepScheduler, model), [], {}), + SamplerData('DPM++ 2M SDE', lambda model: DiffusionSampler('DPM++ 2M SDE', DPMSolverMultistepScheduler, model), [], {}), + SamplerData('DPM++ 2M EDM', lambda model: DiffusionSampler('DPM++ 2M EDM', EDMDPMSolverMultistepScheduler, model), [], {}), + SamplerData('DPM++ Cosine', lambda model: DiffusionSampler('DPM++ 2M EDM', CosineDPMSolverMultistepScheduler, model), [], {}), + SamplerData('DPM SDE', lambda model: DiffusionSampler('DPM SDE', DPMSolverSDEScheduler, model), [], {}), - sd_samplers_common.SamplerData('DPM2 FlowMatch', lambda model: DiffusionSampler('DPM2 FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}), - sd_samplers_common.SamplerData('DPM2a FlowMatch', lambda model: DiffusionSampler('DPM2a FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}), - sd_samplers_common.SamplerData('DPM2++ 2M FlowMatch', lambda model: DiffusionSampler('DPM2++ 2M FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}), - sd_samplers_common.SamplerData('DPM2++ 2S FlowMatch', lambda model: DiffusionSampler('DPM2++ 2S FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}), - sd_samplers_common.SamplerData('DPM2++ SDE FlowMatch', lambda model: DiffusionSampler('DPM2++ SDE FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}), - sd_samplers_common.SamplerData('DPM2++ 2M SDE FlowMatch', lambda model: DiffusionSampler('DPM2++ 2M SDE FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}), - sd_samplers_common.SamplerData('DPM2++ 3M SDE FlowMatch', lambda model: DiffusionSampler('DPM2++ 3M SDE FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}), + SamplerData('DPM2 FlowMatch', lambda model: DiffusionSampler('DPM2 FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}), + SamplerData('DPM2a FlowMatch', lambda model: DiffusionSampler('DPM2a FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}), + SamplerData('DPM2++ 2M FlowMatch', lambda model: DiffusionSampler('DPM2++ 2M FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}), + SamplerData('DPM2++ 2S FlowMatch', lambda model: DiffusionSampler('DPM2++ 2S FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}), + SamplerData('DPM2++ SDE FlowMatch', lambda model: DiffusionSampler('DPM2++ SDE FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}), + SamplerData('DPM2++ 2M SDE FlowMatch', lambda model: DiffusionSampler('DPM2++ 2M SDE FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}), + SamplerData('DPM2++ 3M SDE FlowMatch', lambda model: DiffusionSampler('DPM2++ 3M SDE FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}), - sd_samplers_common.SamplerData('Heun', lambda model: DiffusionSampler('Heun', HeunDiscreteScheduler, model), [], {}), - sd_samplers_common.SamplerData('Heun FlowMatch', lambda model: DiffusionSampler('Heun FlowMatch', FlowMatchHeunDiscreteScheduler, model), [], {}), + SamplerData('Heun', lambda model: DiffusionSampler('Heun', HeunDiscreteScheduler, model), [], {}), + SamplerData('Heun FlowMatch', lambda model: DiffusionSampler('Heun FlowMatch', FlowMatchHeunDiscreteScheduler, model), [], {}), - sd_samplers_common.SamplerData('DEIS', lambda model: 
DiffusionSampler('DEIS', DEISMultistepScheduler, model), [], {}), - sd_samplers_common.SamplerData('SA Solver', lambda model: DiffusionSampler('SA Solver', SASolverScheduler, model), [], {}), - sd_samplers_common.SamplerData('DC Solver', lambda model: DiffusionSampler('DC Solver', DCSolverMultistepScheduler, model), [], {}), - sd_samplers_common.SamplerData('VDM Solver', lambda model: DiffusionSampler('VDM Solver', VDMScheduler, model), [], {}), + SamplerData('DEIS', lambda model: DiffusionSampler('DEIS', DEISMultistepScheduler, model), [], {}), + SamplerData('SA Solver', lambda model: DiffusionSampler('SA Solver', SASolverScheduler, model), [], {}), + SamplerData('DC Solver', lambda model: DiffusionSampler('DC Solver', DCSolverMultistepScheduler, model), [], {}), + SamplerData('VDM Solver', lambda model: DiffusionSampler('VDM Solver', VDMScheduler, model), [], {}), + SamplerData('BDIA DDIM', lambda model: DiffusionSampler('BDIA DDIM g=0', BDIA_DDIMScheduler, model), [], {}), - sd_samplers_common.SamplerData('PNDM', lambda model: DiffusionSampler('PNDM', PNDMScheduler, model), [], {}), - sd_samplers_common.SamplerData('IPNDM', lambda model: DiffusionSampler('IPNDM', IPNDMScheduler, model), [], {}), - sd_samplers_common.SamplerData('DDPM', lambda model: DiffusionSampler('DDPM', DDPMScheduler, model), [], {}), - sd_samplers_common.SamplerData('LMSD', lambda model: DiffusionSampler('LMSD', LMSDiscreteScheduler, model), [], {}), - sd_samplers_common.SamplerData('KDPM2', lambda model: DiffusionSampler('KDPM2', KDPM2DiscreteScheduler, model), [], {}), - sd_samplers_common.SamplerData('KDPM2 a', lambda model: DiffusionSampler('KDPM2 a', KDPM2AncestralDiscreteScheduler, model), [], {}), - sd_samplers_common.SamplerData('CMSI', lambda model: DiffusionSampler('CMSI', CMStochasticIterativeScheduler, model), [], {}), + SamplerData('PNDM', lambda model: DiffusionSampler('PNDM', PNDMScheduler, model), [], {}), + SamplerData('IPNDM', lambda model: DiffusionSampler('IPNDM', IPNDMScheduler, model), [], {}), + SamplerData('DDPM', lambda model: DiffusionSampler('DDPM', DDPMScheduler, model), [], {}), + SamplerData('LMSD', lambda model: DiffusionSampler('LMSD', LMSDiscreteScheduler, model), [], {}), + SamplerData('KDPM2', lambda model: DiffusionSampler('KDPM2', KDPM2DiscreteScheduler, model), [], {}), + SamplerData('KDPM2 a', lambda model: DiffusionSampler('KDPM2 a', KDPM2AncestralDiscreteScheduler, model), [], {}), + SamplerData('CMSI', lambda model: DiffusionSampler('CMSI', CMStochasticIterativeScheduler, model), [], {}), - sd_samplers_common.SamplerData('LCM', lambda model: DiffusionSampler('LCM', LCMScheduler, model), [], {}), - sd_samplers_common.SamplerData('TCD', lambda model: DiffusionSampler('TCD', TCDScheduler, model), [], {}), + SamplerData('LCM', lambda model: DiffusionSampler('LCM', LCMScheduler, model), [], {}), + SamplerData('TCD', lambda model: DiffusionSampler('TCD', TCDScheduler, model), [], {}), + SamplerData('UFOGen', lambda model: DiffusionSampler('UFOGen', UFOGenScheduler, model), [], {}), - sd_samplers_common.SamplerData('Same as primary', None, [], {}), + SamplerData('Same as primary', None, [], {}), ] @@ -175,14 +180,14 @@ def __init__(self, name, constructor, model, **kwargs): orig_config = model.default_scheduler.scheduler_config else: orig_config = model.default_scheduler.config - for key, value in config.get(name, {}).items(): # apply diffusers per-scheduler defaults - self.config[key] = value debug(f'Sampler: diffusers="{self.config}"') debug(f'Sampler: original="{orig_config}"') 
for key, value in orig_config.items(): # apply model defaults if key in self.config: self.config[key] = value debug(f'Sampler: default="{self.config}"') + for key, value in config.get(name, {}).items(): # apply diffusers per-scheduler defaults + self.config[key] = value for key, value in kwargs.items(): # apply user args, if any if key in self.config: self.config[key] = value @@ -200,16 +205,20 @@ def __init__(self, name, constructor, model, **kwargs): timesteps = re.split(',| ', shared.opts.schedulers_timesteps) timesteps = [int(x) for x in timesteps if x.isdigit()] if len(timesteps) == 0: - if 'use_beta_sigmas' in self.config: - self.config['use_beta_sigmas'] = shared.opts.schedulers_sigma == 'beta' - if 'use_karras_sigmas' in self.config: - self.config['use_karras_sigmas'] = shared.opts.schedulers_sigma == 'karras' - if 'use_exponential_sigmas' in self.config: - self.config['use_exponential_sigmas'] = shared.opts.schedulers_sigma == 'exponential' - if 'use_lu_lambdas' in self.config: - self.config['use_lu_lambdas'] = shared.opts.schedulers_sigma == 'lambdas' if 'sigma_schedule' in self.config: self.config['sigma_schedule'] = shared.opts.schedulers_sigma if shared.opts.schedulers_sigma != 'default' else None + if shared.opts.schedulers_sigma == 'default' and shared.sd_model_type in flow_models and 'use_flow_sigmas' in self.config: + self.config['use_flow_sigmas'] = True + elif shared.opts.schedulers_sigma == 'betas' and 'use_beta_sigmas' in self.config: + self.config['use_beta_sigmas'] = True + elif shared.opts.schedulers_sigma == 'karras' and 'use_karras_sigmas' in self.config: + self.config['use_karras_sigmas'] = True + elif shared.opts.schedulers_sigma == 'flowmatch' and 'use_flow_sigmas' in self.config: + self.config['use_flow_sigmas'] = True + elif shared.opts.schedulers_sigma == 'exponential' and 'use_exponential_sigmas' in self.config: + self.config['use_exponential_sigmas'] = True + elif shared.opts.schedulers_sigma == 'lambdas' and 'use_lu_lambdas' in self.config: + self.config['use_lu_lambdas'] = True else: pass # timesteps are set using set_timesteps in set_pipeline_args @@ -236,7 +245,7 @@ def __init__(self, name, constructor, model, **kwargs): if 'use_dynamic_shifting' in self.config: if 'Flux' in model.__class__.__name__: self.config['use_dynamic_shifting'] = shared.opts.schedulers_dynamic_shift - if 'use_beta_sigmas' in self.config: + if 'use_beta_sigmas' in self.config and 'sigma_schedule' in self.config: self.config['use_beta_sigmas'] = 'StableDiffusion3' in model.__class__.__name__ if 'rescale_betas_zero_snr' in self.config: self.config['rescale_betas_zero_snr'] = shared.opts.schedulers_rescale_betas diff --git a/modules/sd_unet.py b/modules/sd_unet.py index f730bdb74..deb0b24b0 100644 --- a/modules/sd_unet.py +++ b/modules/sd_unet.py @@ -36,7 +36,7 @@ def load_unet(model): model.prior_pipe.text_encoder = prior_text_encoder.to(devices.device, dtype=devices.dtype) elif "Flux" in model.__class__.__name__ or "StableDiffusion3" in model.__class__.__name__: loaded_unet = shared.opts.sd_unet - sd_models.load_diffuser() # TODO forcing reloading entire model as loading transformers only leads to massive memory usage + sd_models.load_diffuser() # TODO model load: force-reloading entire model as loading transformers only leads to massive memory usage """ from modules.model_flux import load_transformer transformer = load_transformer(unet_dict[shared.opts.sd_unet]) diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 4d213ad48..a1959817c 100644 --- 
a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -169,6 +169,9 @@ def decode(latents): if vae is None: return latents try: + size = max(latents.shape[-1], latents.shape[-2]) + if size > 256: + return latents with devices.inference_context(): latents = latents.detach().clone().to(devices.device, dtype) if len(latents.shape) == 3: diff --git a/modules/segmoe/segmoe_model.py b/modules/segmoe/segmoe_model.py index 4b96527be..a542c1fbb 100644 --- a/modules/segmoe/segmoe_model.py +++ b/modules/segmoe/segmoe_model.py @@ -136,7 +136,7 @@ def __init__(self, config_or_path, **kwargs) -> Any: memory_format=torch.channels_last, ) - def to(self, *args, **kwargs): # TODO added no-op to avoid error + def to(self, *args, **kwargs): self.pipe.to(*args, **kwargs) def load_from_scratch(self, config: str, **kwargs) -> None: @@ -202,7 +202,6 @@ def load_from_scratch(self, config: str, **kwargs) -> None: self.config["down_idx_start"] = self.down_idx_start self.config["down_idx_end"] = self.down_idx_end - # TODO: Add Support for Scheduler Selection self.pipe.scheduler = DDPMScheduler.from_config(self.pipe.scheduler.config) # Load Experts @@ -242,7 +241,6 @@ def load_from_scratch(self, config: str, **kwargs) -> None: **kwargs, ) - # TODO: Add Support for Scheduler Selection expert.scheduler = DDPMScheduler.from_config( expert.scheduler.config ) diff --git a/modules/shared.py b/modules/shared.py index a89cbbc95..31429e076 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -16,11 +16,11 @@ import orjson import diffusers from rich.console import Console -from modules import errors, devices, shared_items, shared_state, cmd_args, theme, history +from modules import errors, devices, shared_items, shared_state, cmd_args, theme, history, files_cache from modules.paths import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # pylint: disable=W0611 from modules.dml import memory_providers, default_memory_provider, directml_do_hijack from modules.onnx_impl import initialize_onnx, execution_providers -from modules.memstats import memory_stats +from modules.memstats import memory_stats, ram_stats # pylint: disable=unused-import from modules.ui_components import DropdownEditable import modules.interrogate import modules.memmon @@ -132,7 +132,8 @@ def readfile(filename, silent=False, lock=False): # data = json.loads(data) t1 = time.time() if not silent: - log.debug(f'Read: file="{filename}" json={len(data)} bytes={os.path.getsize(filename)} time={t1-t0:.3f}') + fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access + log.debug(f'Read: file="{filename}" json={len(data)} bytes={os.path.getsize(filename)} time={t1-t0:.3f} fn={fn}') except FileNotFoundError as err: log.debug(f'Reading failed: {filename} {err}') except Exception as err: @@ -237,6 +238,8 @@ def default(obj): mem_stat = memory_stats() gpu_memory = mem_stat['gpu']['total'] if "gpu" in mem_stat else 0 native = backend == Backend.DIFFUSERS +if not files_cache.do_cache_folders: + log.warning('File cache disabled: ') class OptionInfo: @@ -363,7 +366,7 @@ def list_samplers(): def temp_disable_extensions(): disable_safe = ['sd-webui-controlnet', 'multidiffusion-upscaler-for-automatic1111', 'a1111-sd-webui-lycoris', 'sd-webui-agent-scheduler', 'clip-interrogator-ext', 'stable-diffusion-webui-rembg', 'sd-extension-chainner', 'stable-diffusion-webui-images-browser'] - disable_diffusers = 
['sd-webui-controlnet', 'multidiffusion-upscaler-for-automatic1111', 'a1111-sd-webui-lycoris', 'sd-webui-animatediff'] + disable_diffusers = ['sd-webui-controlnet', 'multidiffusion-upscaler-for-automatic1111', 'a1111-sd-webui-lycoris', 'sd-webui-animatediff', 'Lora'] disable_themes = ['sd-webui-lobe-theme', 'cozy-nest', 'sdnext-modernui'] disable_original = [] disabled = [] @@ -431,15 +434,15 @@ def get_default_modes(): cmd_opts.lowvram = True default_offload_mode = "sequential" log.info(f"Device detect: memory={gpu_memory:.1f} optimization=lowvram") - elif gpu_memory <= 8: - cmd_opts.medvram = True - default_offload_mode = "model" - log.info(f"Device detect: memory={gpu_memory:.1f} optimization=medvram") + # elif gpu_memory <= 8: + # cmd_opts.medvram = True + # default_offload_mode = "model" + # log.info(f"Device detect: memory={gpu_memory:.1f} optimization=medvram") else: - default_offload_mode = "none" - log.info(f"Device detect: memory={gpu_memory:.1f} optimization=none") + default_offload_mode = "balanced" + log.info(f"Device detect: memory={gpu_memory:.1f} optimization=balanced") elif cmd_opts.medvram: - default_offload_mode = "model" + default_offload_mode = "balanced" elif cmd_opts.lowvram: default_offload_mode = "sequential" @@ -449,16 +452,16 @@ def get_default_modes(): default_cross_attention = "Scaled-Dot-Product" if native else "Doggettx's" elif devices.backend == "mps": default_cross_attention = "Scaled-Dot-Product" if native else "Doggettx's" - else: # cuda, rocm, ipex, openvino - default_cross_attention ="Scaled-Dot-Product" + else: # cuda, rocm, zluda, ipex, openvino + default_cross_attention = "Scaled-Dot-Product" if devices.backend == "rocm": default_sdp_options = ['Memory attention', 'Math attention'] elif devices.backend == "zluda": - default_sdp_options = ['Math attention'] + default_sdp_options = ['Math attention', 'Dynamic attention'] else: default_sdp_options = ['Flash attention', 'Memory attention', 'Math attention'] - if (cmd_opts.lowvram or cmd_opts.medvram) and ('Flash attention' not in default_sdp_options): + if (cmd_opts.lowvram or cmd_opts.medvram) and ('Flash attention' not in default_sdp_options and 'Dynamic attention' not in default_sdp_options): default_sdp_options.append('Dynamic attention') return default_offload_mode, default_cross_attention, default_sdp_options @@ -466,165 +469,165 @@ def get_default_modes(): startup_offload_mode, startup_cross_attention, startup_sdp_options = get_default_modes() -options_templates.update(options_section(('sd', "Execution & Models"), { +options_templates.update(options_section(('sd', "Models & Loading"), { "sd_backend": OptionInfo(default_backend, "Execution backend", gr.Radio, {"choices": ["diffusers", "original"] }), + "diffusers_pipeline": OptionInfo('Autodetect', 'Model pipeline', gr.Dropdown, lambda: {"choices": list(shared_items.get_pipelines()), "visible": native}), "sd_model_checkpoint": OptionInfo(default_checkpoint, "Base model", DropdownEditable, lambda: {"choices": list_checkpoint_titles()}, refresh=refresh_checkpoints), "sd_model_refiner": OptionInfo('None', "Refiner model", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_titles()}, refresh=refresh_checkpoints), - "sd_vae": OptionInfo("Automatic", "VAE model", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list), "sd_unet": OptionInfo("None", "UNET model", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list), - "sd_text_encoder": 
OptionInfo('None', "Text encoder model", gr.Dropdown, lambda: {"choices": shared_items.sd_te_items()}, refresh=shared_items.refresh_te_list), - "sd_model_dict": OptionInfo('None', "Use separate base dict", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_titles()}, refresh=refresh_checkpoints), + "latent_history": OptionInfo(16, "Latent history size", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), + + "offload_sep": OptionInfo("

<h2>Model Offloading</h2>
", "", gr.HTML), + "diffusers_move_base": OptionInfo(False, "Move base model to CPU when using refiner", gr.Checkbox, {"visible": False }), + "diffusers_move_unet": OptionInfo(False, "Move base model to CPU when using VAE", gr.Checkbox, {"visible": False }), + "diffusers_move_refiner": OptionInfo(False, "Move refiner model to CPU when not in use", gr.Checkbox, {"visible": False }), + "diffusers_extract_ema": OptionInfo(False, "Use model EMA weights when possible", gr.Checkbox, {"visible": False }), + "diffusers_offload_mode": OptionInfo(startup_offload_mode, "Model offload mode", gr.Radio, {"choices": ['none', 'balanced', 'model', 'sequential']}), + "diffusers_offload_min_gpu_memory": OptionInfo(0.25, "Balanced offload GPU low watermark", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01 }), + "diffusers_offload_max_gpu_memory": OptionInfo(0.70, "Balanced offload GPU high watermark", gr.Slider, {"minimum": 0.1, "maximum": 1, "step": 0.01 }), + "diffusers_offload_max_cpu_memory": OptionInfo(0.90, "Balanced offload CPU high watermark", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": False }), + + "advanced_sep": OptionInfo("

<h2>Advanced Options</h2>
", "", gr.HTML), "sd_checkpoint_autoload": OptionInfo(True, "Model autoload on start"), "sd_checkpoint_autodownload": OptionInfo(True, "Model auto-download on demand"), - "sd_textencoder_cache": OptionInfo(True, "Cache text encoder results", gr.Checkbox, {"visible": False}), - "sd_textencoder_cache_size": OptionInfo(4, "Text encoder results LRU cache size", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), - "stream_load": OptionInfo(False, "Load models using stream loading method", gr.Checkbox, {"visible": not native }), + "stream_load": OptionInfo(False, "Model load using streams", gr.Checkbox), + "diffusers_eval": OptionInfo(True, "Force model eval", gr.Checkbox, {"visible": False }), + "diffusers_to_gpu": OptionInfo(False, "Load model directly to GPU"), + "disable_accelerate": OptionInfo(False, "Disable accelerate", gr.Checkbox, {"visible": False }), + "sd_model_dict": OptionInfo('None', "Use separate base dict", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_titles()}, refresh=refresh_checkpoints), + "sd_checkpoint_cache": OptionInfo(0, "Cached models", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1, "visible": not native }), +})) + +options_templates.update(options_section(('vae_encoder', "Variable Auto Encoder"), { + "sd_vae": OptionInfo("Automatic", "VAE model", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list), + "diffusers_vae_upcast": OptionInfo("default", "VAE upcasting", gr.Radio, {"choices": ['default', 'true', 'false']}), + "no_half_vae": OptionInfo(False if not cmd_opts.use_openvino else True, "Full precision (--no-half-vae)"), + "diffusers_vae_slicing": OptionInfo(True, "VAE slicing", gr.Checkbox, {"visible": native}), + "diffusers_vae_tiling": OptionInfo(cmd_opts.lowvram or cmd_opts.medvram, "VAE tiling", gr.Checkbox, {"visible": native}), + "sd_vae_sliced_encode": OptionInfo(False, "VAE sliced encode", gr.Checkbox, {"visible": not native}), + "nan_skip": OptionInfo(False, "Skip Generation if NaN found in latents", gr.Checkbox), + "rollback_vae": OptionInfo(False, "Attempt VAE roll back for NaN values"), +})) + +options_templates.update(options_section(('text_encoder', "Text Encoder"), { + "sd_text_encoder": OptionInfo('None', "Text encoder model", gr.Dropdown, lambda: {"choices": shared_items.sd_te_items()}, refresh=shared_items.refresh_te_list), + "prompt_attention": OptionInfo("native", "Prompt attention parser", gr.Radio, {"choices": ["native", "compel", "xhinker", "a1111", "fixed"] }), "prompt_mean_norm": OptionInfo(False, "Prompt attention normalization", gr.Checkbox), + "sd_textencoder_cache": OptionInfo(True, "Cache text encoder results", gr.Checkbox, {"visible": False}), + "sd_textencoder_cache_size": OptionInfo(4, "Text encoder cache size", gr.Slider, {"minimum": 0, "maximum": 16, "step": 1}), "comma_padding_backtrack": OptionInfo(20, "Prompt padding", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1, "visible": not native }), - "prompt_attention": OptionInfo("native", "Prompt attention parser", gr.Radio, {"choices": ["native", "compel", "xhinker", "a1111", "fixed"] }), - "latent_history": OptionInfo(16, "Latent history size", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), - "sd_checkpoint_cache": OptionInfo(0, "Cached models", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1, "visible": not native }), + "diffusers_zeros_prompt_pad": OptionInfo(False, "Use zeros for prompt padding", gr.Checkbox), + "diffusers_pooled": OptionInfo("default", "Diffusers SDXL pooled embeds", 
gr.Radio, {"choices": ['default', 'weighted']}), })) options_templates.update(options_section(('cuda', "Compute Settings"), { - "math_sep": OptionInfo("

<h2>Execution precision</h2>
", "", gr.HTML), + "math_sep": OptionInfo("

<h2>Execution Precision</h2>
", "", gr.HTML), "precision": OptionInfo("Autocast", "Precision type", gr.Radio, {"choices": ["Autocast", "Full"]}), "cuda_dtype": OptionInfo("Auto", "Device precision type", gr.Radio, {"choices": ["Auto", "FP32", "FP16", "BF16"]}), + "no_half": OptionInfo(False if not cmd_opts.use_openvino else True, "Full precision (--no-half)", None, None, None), + "upcast_sampling": OptionInfo(False if sys.platform != "darwin" else True, "Upcast sampling", gr.Checkbox, {"visible": not native}), + "upcast_attn": OptionInfo(False, "Upcast attention layer", gr.Checkbox, {"visible": not native}), + "cuda_cast_unet": OptionInfo(False, "Fixed UNet precision", gr.Checkbox, {"visible": not native}), - "model_sep": OptionInfo("

<h2>Model options</h2>
", "", gr.HTML), - "no_half": OptionInfo(False if not cmd_opts.use_openvino else True, "Full precision for model (--no-half)", None, None, None), - "no_half_vae": OptionInfo(False if not cmd_opts.use_openvino else True, "Full precision for VAE (--no-half-vae)"), - "upcast_sampling": OptionInfo(False if sys.platform != "darwin" else True, "Upcast sampling"), - "upcast_attn": OptionInfo(False, "Upcast attention layer"), - "cuda_cast_unet": OptionInfo(False, "Fixed UNet precision"), - "nan_skip": OptionInfo(False, "Skip Generation if NaN found in latents", gr.Checkbox), - "rollback_vae": OptionInfo(False, "Attempt VAE roll back for NaN values"), + "generator_sep": OptionInfo("

<h2>Noise Options</h2>
", "", gr.HTML), + "diffusers_generator_device": OptionInfo("GPU", "Generator device", gr.Radio, {"choices": ["GPU", "CPU", "Unset"]}), "cross_attention_sep": OptionInfo("

<h2>Cross Attention</h2>
", "", gr.HTML), - "cross_attention_optimization": OptionInfo(startup_cross_attention, "Attention optimization method", gr.Radio, lambda: {"choices": shared_items.list_crossattention(native) }), - "sdp_options": OptionInfo(startup_sdp_options, "SDP options", gr.CheckboxGroup, {"choices": ['Flash attention', 'Memory attention', 'Math attention', 'Dynamic attention', 'Sage attention'] }), + "cross_attention_optimization": OptionInfo(startup_cross_attention, "Attention optimization method", gr.Radio, lambda: {"choices": shared_items.list_crossattention(native)}), + "sdp_options": OptionInfo(startup_sdp_options, "SDP options", gr.CheckboxGroup, {"choices": ['Flash attention', 'Memory attention', 'Math attention', 'Dynamic attention', 'Sage attention'], "visible": native}), "xformers_options": OptionInfo(['Flash attention'], "xFormers options", gr.CheckboxGroup, {"choices": ['Flash attention'] }), "dynamic_attention_slice_rate": OptionInfo(4, "Dynamic Attention slicing rate in GB", gr.Slider, {"minimum": 0.1, "maximum": gpu_memory, "step": 0.1, "visible": native}), "sub_quad_sep": OptionInfo("

<h2>Sub-quadratic options</h2>
", "", gr.HTML, {"visible": not native}), "sub_quad_q_chunk_size": OptionInfo(512, "Attention query chunk size", gr.Slider, {"minimum": 16, "maximum": 8192, "step": 8, "visible": not native}), "sub_quad_kv_chunk_size": OptionInfo(512, "Attention kv chunk size", gr.Slider, {"minimum": 0, "maximum": 8192, "step": 8, "visible": not native}), "sub_quad_chunk_threshold": OptionInfo(80, "Attention chunking threshold", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1, "visible": not native}), +})) - "other_sep": OptionInfo("

<h2>Execution options</h2>
", "", gr.HTML), - "opt_channelslast": OptionInfo(False, "Use channels last "), - "cudnn_deterministic": OptionInfo(False, "Use deterministic mode"), - "cudnn_benchmark": OptionInfo(False, "Full-depth cuDNN benchmark feature"), +options_templates.update(options_section(('backends', "Backend Settings"), { + "other_sep": OptionInfo("

<h2>Torch Options</h2>
", "", gr.HTML), + "opt_channelslast": OptionInfo(False, "Channels last "), + "cudnn_deterministic": OptionInfo(False, "Deterministic mode"), + "cudnn_benchmark": OptionInfo(False, "Full-depth cuDNN benchmark"), "diffusers_fuse_projections": OptionInfo(False, "Fused projections"), - "torch_expandable_segments": OptionInfo(False, "Torch expandable segments"), - "cuda_mem_fraction": OptionInfo(0.0, "Torch memory limit", gr.Slider, {"minimum": 0, "maximum": 2.0, "step": 0.05}), - "torch_gc_threshold": OptionInfo(80, "Torch memory threshold for GC", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}), - "torch_malloc": OptionInfo("native", "Torch memory allocator", gr.Radio, {"choices": ['native', 'cudaMallocAsync'] }), + "torch_expandable_segments": OptionInfo(False, "Expandable segments"), + "cuda_mem_fraction": OptionInfo(0.0, "Memory limit", gr.Slider, {"minimum": 0, "maximum": 2.0, "step": 0.05}), + "torch_gc_threshold": OptionInfo(70, "GC threshold", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}), + "inference_mode": OptionInfo("no-grad", "Inference mode", gr.Radio, {"choices": ["no-grad", "inference-mode", "none"]}), + "torch_malloc": OptionInfo("native", "Memory allocator", gr.Radio, {"choices": ['native', 'cudaMallocAsync'] }), + + "onnx_sep": OptionInfo("

<h2>ONNX</h2>
", "", gr.HTML), + "onnx_execution_provider": OptionInfo(execution_providers.get_default_execution_provider().value, 'ONNX Execution Provider', gr.Dropdown, lambda: {"choices": execution_providers.available_execution_providers }), + "onnx_cpu_fallback": OptionInfo(True, 'ONNX allow fallback to CPU'), + "onnx_cache_converted": OptionInfo(True, 'ONNX cache converted models'), + "onnx_unload_base": OptionInfo(False, 'ONNX unload base model when processing refiner'), - "cuda_compile_sep": OptionInfo("

<h2>Model Compile</h2>
", "", gr.HTML), - "cuda_compile": OptionInfo([] if not cmd_opts.use_openvino else ["Model", "VAE"], "Compile Model", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "Upscaler"]}), - "cuda_compile_backend": OptionInfo("none" if not cmd_opts.use_openvino else "openvino_fx", "Model compile backend", gr.Radio, {"choices": ['none', 'inductor', 'cudagraphs', 'aot_ts_nvfuser', 'hidet', 'migraphx', 'ipex', 'onediff', 'stable-fast', 'deep-cache', 'olive-ai', 'openvino_fx']}), - "cuda_compile_mode": OptionInfo("default", "Model compile mode", gr.Radio, {"choices": ['default', 'reduce-overhead', 'max-autotune', 'max-autotune-no-cudagraphs']}), - "cuda_compile_fullgraph": OptionInfo(True if not cmd_opts.use_openvino else False, "Model compile fullgraph"), - "cuda_compile_precompile": OptionInfo(False, "Model compile precompile"), - "cuda_compile_verbose": OptionInfo(False, "Model compile verbose mode"), - "cuda_compile_errors": OptionInfo(True, "Model compile suppress errors"), - "deep_cache_interval": OptionInfo(3, "DeepCache cache interval", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), + "olive_sep": OptionInfo("

<h2>Olive</h2>
", "", gr.HTML), + "olive_float16": OptionInfo(True, 'Olive use FP16 on optimization'), + "olive_vae_encoder_float32": OptionInfo(False, 'Olive force FP32 for VAE Encoder'), + "olive_static_dims": OptionInfo(True, 'Olive use static dimensions'), + "olive_cache_optimized": OptionInfo(True, 'Olive cache optimized models'), "ipex_sep": OptionInfo("

<h2>IPEX</h2>
", "", gr.HTML, {"visible": devices.backend == "ipex"}), - "ipex_optimize": OptionInfo([], "IPEX Optimize for Intel GPUs", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "Upscaler"], "visible": devices.backend == "ipex"}), + "ipex_optimize": OptionInfo([], "IPEX Optimize", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "Upscaler"], "visible": devices.backend == "ipex"}), "openvino_sep": OptionInfo("

<h2>OpenVINO</h2>
", "", gr.HTML, {"visible": cmd_opts.use_openvino}), "openvino_devices": OptionInfo([], "OpenVINO devices to use", gr.CheckboxGroup, {"choices": get_openvino_device_list() if cmd_opts.use_openvino else [], "visible": cmd_opts.use_openvino}), # pylint: disable=E0606 "openvino_accuracy": OptionInfo("performance", "OpenVINO accuracy mode", gr.Radio, {"choices": ['performance', 'accuracy'], "visible": cmd_opts.use_openvino}), - "openvino_disable_model_caching": OptionInfo(False, "OpenVINO disable model caching", gr.Checkbox, {"visible": cmd_opts.use_openvino}), + "openvino_disable_model_caching": OptionInfo(True, "OpenVINO disable model caching", gr.Checkbox, {"visible": cmd_opts.use_openvino}), "openvino_disable_memory_cleanup": OptionInfo(True, "OpenVINO disable memory cleanup after compile", gr.Checkbox, {"visible": cmd_opts.use_openvino}), "directml_sep": OptionInfo("

<h2>DirectML</h2>
", "", gr.HTML, {"visible": devices.backend == "directml"}), "directml_memory_provider": OptionInfo(default_memory_provider, 'DirectML memory stats provider', gr.Radio, {"choices": memory_providers, "visible": devices.backend == "directml"}), "directml_catch_nan": OptionInfo(False, "DirectML retry ops for NaN", gr.Checkbox, {"visible": devices.backend == "directml"}), - - "olive_sep": OptionInfo("

<h2>Olive</h2>
", "", gr.HTML), - "olive_float16": OptionInfo(True, 'Olive use FP16 on optimization'), - "olive_vae_encoder_float32": OptionInfo(False, 'Olive force FP32 for VAE Encoder'), - "olive_static_dims": OptionInfo(True, 'Olive use static dimensions'), - "olive_cache_optimized": OptionInfo(True, 'Olive cache optimized models'), -})) - -options_templates.update(options_section(('diffusers', "Diffusers Settings"), { - "diffusers_pipeline": OptionInfo('Autodetect', 'Diffusers pipeline', gr.Dropdown, lambda: {"choices": list(shared_items.get_pipelines()) }), - "diffuser_cache_config": OptionInfo(True, "Use cached model config when available"), - "diffusers_move_base": OptionInfo(False, "Move base model to CPU when using refiner"), - "diffusers_move_unet": OptionInfo(False, "Move base model to CPU when using VAE"), - "diffusers_move_refiner": OptionInfo(False, "Move refiner model to CPU when not in use"), - "diffusers_extract_ema": OptionInfo(False, "Use model EMA weights when possible"), - "diffusers_generator_device": OptionInfo("GPU", "Generator device", gr.Radio, {"choices": ["GPU", "CPU", "Unset"]}), - "diffusers_offload_mode": OptionInfo(startup_offload_mode, "Model offload mode", gr.Radio, {"choices": ['none', 'balanced', 'model', 'sequential']}), - "diffusers_offload_max_gpu_memory": OptionInfo(round(gpu_memory * 0.75, 1), "Max GPU memory for balanced offload mode in GB", gr.Slider, {"minimum": 0, "maximum": gpu_memory, "step": 0.01,}), - "diffusers_offload_max_cpu_memory": OptionInfo(round(cpu_memory * 0.75, 1), "Max CPU memory for balanced offload mode in GB", gr.Slider, {"minimum": 0, "maximum": cpu_memory, "step": 0.01,}), - "diffusers_vae_upcast": OptionInfo("default", "VAE upcasting", gr.Radio, {"choices": ['default', 'true', 'false']}), - "diffusers_vae_slicing": OptionInfo(True, "VAE slicing"), - "diffusers_vae_tiling": OptionInfo(cmd_opts.lowvram or cmd_opts.medvram, "VAE tiling"), - "diffusers_model_load_variant": OptionInfo("default", "Preferred Model variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}), - "diffusers_vae_load_variant": OptionInfo("default", "Preferred VAE variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}), - "custom_diffusers_pipeline": OptionInfo('', 'Load custom Diffusers pipeline'), - "diffusers_eval": OptionInfo(True, "Force model eval"), - "diffusers_to_gpu": OptionInfo(False, "Load model directly to GPU"), - "disable_accelerate": OptionInfo(False, "Disable accelerate"), - "diffusers_pooled": OptionInfo("default", "Diffusers SDXL pooled embeds", gr.Radio, {"choices": ['default', 'weighted']}), - "diffusers_zeros_prompt_pad": OptionInfo(False, "Use zeros for prompt padding", gr.Checkbox), - "huggingface_token": OptionInfo('', 'HuggingFace token'), - "enable_linfusion": OptionInfo(False, "Apply LinFusion distillation on load"), - - "onnx_sep": OptionInfo("

<h2>ONNX Runtime</h2>
", "", gr.HTML), - "onnx_execution_provider": OptionInfo(execution_providers.get_default_execution_provider().value, 'Execution Provider', gr.Dropdown, lambda: {"choices": execution_providers.available_execution_providers }), - "onnx_cpu_fallback": OptionInfo(True, 'ONNX allow fallback to CPU'), - "onnx_cache_converted": OptionInfo(True, 'ONNX cache converted models'), - "onnx_unload_base": OptionInfo(False, 'ONNX unload base model when processing refiner'), })) options_templates.update(options_section(('quantization', "Quantization Settings"), { - "bnb_quantization": OptionInfo([], "BnB quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": native}), - "bnb_quantization_type": OptionInfo("nf4", "BnB quantization type", gr.Radio, {"choices": ['nf4', 'fp8', 'fp4'], "visible": native}), - "bnb_quantization_storage": OptionInfo("uint8", "BnB quantization storage", gr.Radio, {"choices": ["float16", "float32", "int8", "uint8", "float64", "bfloat16"], "visible": native}), - "optimum_quanto_weights": OptionInfo([], "Optimum.quanto quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "ControlNet"], "visible": native}), - "optimum_quanto_weights_type": OptionInfo("qint8", "Optimum.quanto quantization type", gr.Radio, {"choices": ['qint8', 'qfloat8_e4m3fn', 'qfloat8_e5m2', 'qint4', 'qint2'], "visible": native}), - "optimum_quanto_activations_type": OptionInfo("none", "Optimum.quanto quantization activations ", gr.Radio, {"choices": ['none', 'qint8', 'qfloat8_e4m3fn', 'qfloat8_e5m2'], "visible": native}), - "torchao_quantization": OptionInfo([], "TorchAO quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": native}), - "torchao_quantization_type": OptionInfo("int8", "TorchAO quantization type", gr.Radio, {"choices": ["int8+act", "int8", "int4", "fp8+act", "fp8", "fpx"], "visible": native}), - "nncf_compress_weights": OptionInfo([], "NNCF compression enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "ControlNet"], "visible": native}), - "nncf_compress_weights_mode": OptionInfo("INT8", "NNCF compress mode", gr.Radio, {"choices": ['INT8', 'INT8_SYM', 'INT4_ASYM', 'INT4_SYM', 'NF4'] if cmd_opts.use_openvino else ['INT8']}), - "nncf_compress_weights_raito": OptionInfo(1.0, "NNCF compress ratio", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": cmd_opts.use_openvino}), - "nncf_quantize": OptionInfo([], "NNCF OpenVINO quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": cmd_opts.use_openvino}), - "nncf_quant_mode": OptionInfo("INT8", "NNCF OpenVINO quantization mode", gr.Radio, {"choices": ['INT8', 'FP8_E4M3', 'FP8_E5M2'], "visible": cmd_opts.use_openvino}), - - "quant_shuffle_weights": OptionInfo(False, "Shuffle the weights between GPU and CPU when quantizing", gr.Checkbox, {"visible": native}), + "bnb_sep": OptionInfo("

<h2>BitsAndBytes</h2>
", "", gr.HTML), + "bnb_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": native}), + "bnb_quantization_type": OptionInfo("nf4", "Quantization type", gr.Dropdown, {"choices": ['nf4', 'fp8', 'fp4'], "visible": native}), + "bnb_quantization_storage": OptionInfo("uint8", "Backend storage", gr.Dropdown, {"choices": ["float16", "float32", "int8", "uint8", "float64", "bfloat16"], "visible": native}), + "optimum_quanto_sep": OptionInfo("

<h2>Optimum Quanto</h2>
", "", gr.HTML), + "optimum_quanto_weights": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "ControlNet"], "visible": native}), + "optimum_quanto_weights_type": OptionInfo("qint8", "Quantization weights type", gr.Dropdown, {"choices": ['qint8', 'qfloat8_e4m3fn', 'qfloat8_e5m2', 'qint4', 'qint2'], "visible": native}), + "optimum_quanto_activations_type": OptionInfo("none", "Quantization activations type ", gr.Dropdown, {"choices": ['none', 'qint8', 'qfloat8_e4m3fn', 'qfloat8_e5m2'], "visible": native}), + "torchao_sep": OptionInfo("

<h2>TorchAO</h2>
", "", gr.HTML), + "torchao_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": native}), + "torchao_quantization_mode": OptionInfo("pre", "Quantization mode", gr.Dropdown, {"choices": ['pre', 'post'], "visible": native}), + "torchao_quantization_type": OptionInfo("int8_weight_only", "Quantization type", gr.Dropdown, {"choices": ['int4_weight_only', 'int8_dynamic_activation_int4_weight', 'int8_weight_only', 'int8_dynamic_activation_int8_weight', 'float8_weight_only', 'float8_dynamic_activation_float8_weight', 'float8_static_activation_float8_weight'], "visible": native}), + "nncf_sep": OptionInfo("

<h2>NNCF</h2>
", "", gr.HTML), + "nncf_compress_weights": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "ControlNet"], "visible": native}), + "nncf_compress_weights_mode": OptionInfo("INT8", "Quantization type", gr.Dropdown, {"choices": ['INT8', 'INT8_SYM', 'INT4_ASYM', 'INT4_SYM', 'NF4'] if cmd_opts.use_openvino else ['INT8']}), + "nncf_compress_weights_raito": OptionInfo(1.0, "Compress ratio", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": cmd_opts.use_openvino}), + "nncf_quantize": OptionInfo([], "OpenVINO enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": cmd_opts.use_openvino}), + "nncf_quant_mode": OptionInfo("INT8", "OpenVINO mode", gr.Dropdown, {"choices": ['INT8', 'FP8_E4M3', 'FP8_E5M2'], "visible": cmd_opts.use_openvino}), + "quant_shuffle_weights": OptionInfo(False, "Shuffle weights", gr.Checkbox, {"visible": native}), })) -options_templates.update(options_section(('advanced', "Inference Settings"), { - "token_merging_sep": OptionInfo("

<h2>Token merging</h2>
", "", gr.HTML), +options_templates.update(options_section(('advanced', "Pipeline Modifiers"), { + "token_merging_sep": OptionInfo("

<h2>Token Merging</h2>
", "", gr.HTML), "token_merging_method": OptionInfo("None", "Token merging method", gr.Radio, {"choices": ['None', 'ToMe', 'ToDo']}), "tome_ratio": OptionInfo(0.0, "ToMe token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.05}), "todo_ratio": OptionInfo(0.0, "ToDo token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.05}), "freeu_sep": OptionInfo("

<h2>FreeU</h2>
", "", gr.HTML), "freeu_enabled": OptionInfo(False, "FreeU"), - "freeu_b1": OptionInfo(1.2, "1st stage backbone factor", gr.Slider, {"minimum": 1.0, "maximum": 2.0, "step": 0.01}), - "freeu_b2": OptionInfo(1.4, "2nd stage backbone factor", gr.Slider, {"minimum": 1.0, "maximum": 2.0, "step": 0.01}), - "freeu_s1": OptionInfo(0.9, "1st stage skip factor", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "freeu_s2": OptionInfo(0.2, "2nd stage skip factor", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "freeu_b1": OptionInfo(1.2, "1st stage backbone", gr.Slider, {"minimum": 1.0, "maximum": 2.0, "step": 0.01}), + "freeu_b2": OptionInfo(1.4, "2nd stage backbone", gr.Slider, {"minimum": 1.0, "maximum": 2.0, "step": 0.01}), + "freeu_s1": OptionInfo(0.9, "1st stage skip", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "freeu_s2": OptionInfo(0.2, "2nd stage skip", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "pag_sep": OptionInfo("

<h2>Perturbed-Attention Guidance</h2>
", "", gr.HTML), "pag_apply_layers": OptionInfo("m0", "PAG layer names"), "hypertile_sep": OptionInfo("

<h2>HyperTile</h2>
", "", gr.HTML), - "hypertile_hires_only": OptionInfo(False, "HyperTile hires pass only"), - "hypertile_unet_enabled": OptionInfo(False, "HyperTile UNet"), - "hypertile_unet_tile": OptionInfo(0, "HyperTile UNet tile size", gr.Slider, {"minimum": 0, "maximum": 1024, "step": 8}), - "hypertile_unet_swap_size": OptionInfo(1, "HyperTile UNet swap size", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), - "hypertile_unet_depth": OptionInfo(0, "HyperTile UNet depth", gr.Slider, {"minimum": 0, "maximum": 4, "step": 1}), - "hypertile_vae_enabled": OptionInfo(False, "HyperTile VAE", gr.Checkbox), - "hypertile_vae_tile": OptionInfo(128, "HyperTile VAE tile size", gr.Slider, {"minimum": 0, "maximum": 1024, "step": 8}), - "hypertile_vae_swap_size": OptionInfo(1, "HyperTile VAE swap size", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), + "hypertile_hires_only": OptionInfo(False, "HiRes pass only"), + "hypertile_unet_enabled": OptionInfo(False, "UNet Enabled"), + "hypertile_unet_tile": OptionInfo(0, "UNet tile size", gr.Slider, {"minimum": 0, "maximum": 1024, "step": 8}), + "hypertile_unet_swap_size": OptionInfo(1, "UNet swap size", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), + "hypertile_unet_depth": OptionInfo(0, "UNet depth", gr.Slider, {"minimum": 0, "maximum": 4, "step": 1}), + "hypertile_vae_enabled": OptionInfo(False, "VAE Enabled", gr.Checkbox), + "hypertile_vae_tile": OptionInfo(128, "VAE tile size", gr.Slider, {"minimum": 0, "maximum": 1024, "step": 8}), + "hypertile_vae_swap_size": OptionInfo(1, "VAE swap size", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), "hidiffusion_sep": OptionInfo("

<h2>HiDiffusion</h2>
", "", gr.HTML), "hidiffusion_raunet": OptionInfo(True, "Apply RAU-Net"), @@ -633,16 +636,28 @@ def get_default_modes(): "hidiffusion_t1": OptionInfo(-1, "Override T1 ratio", gr.Slider, {"minimum": -1, "maximum": 1.0, "step": 0.05}), "hidiffusion_t2": OptionInfo(-1, "Override T2 ratio", gr.Slider, {"minimum": -1, "maximum": 1.0, "step": 0.05}), + "linfusion_sep": OptionInfo("

<h2>Batch</h2>
", "", gr.HTML), + "enable_linfusion": OptionInfo(False, "Apply LinFusion distillation on load"), + "inference_batch_sep": OptionInfo("

<h2>Batch</h2>
", "", gr.HTML), "sequential_seed": OptionInfo(True, "Batch mode uses sequential seeds"), "batch_frame_mode": OptionInfo(False, "Parallel process images in batch"), - "inference_other_sep": OptionInfo("

<h2>Other</h2>
", "", gr.HTML), - "inference_mode": OptionInfo("no-grad", "Torch inference mode", gr.Radio, {"choices": ["no-grad", "inference-mode", "none"]}), - "sd_vae_sliced_encode": OptionInfo(False, "VAE sliced encode", gr.Checkbox, {"visible": not native}), +})) + +options_templates.update(options_section(('compile', "Model Compile"), { + "cuda_compile_sep": OptionInfo("

<h2>Model Compile</h2>
", "", gr.HTML), + "cuda_compile": OptionInfo([] if not cmd_opts.use_openvino else ["Model", "VAE"], "Compile Model", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "Upscaler"]}), + "cuda_compile_backend": OptionInfo("none" if not cmd_opts.use_openvino else "openvino_fx", "Model compile backend", gr.Radio, {"choices": ['none', 'inductor', 'cudagraphs', 'aot_ts_nvfuser', 'hidet', 'migraphx', 'ipex', 'onediff', 'stable-fast', 'deep-cache', 'olive-ai', 'openvino_fx']}), + "cuda_compile_mode": OptionInfo("default", "Model compile mode", gr.Radio, {"choices": ['default', 'reduce-overhead', 'max-autotune', 'max-autotune-no-cudagraphs']}), + "cuda_compile_fullgraph": OptionInfo(True if not cmd_opts.use_openvino else False, "Model compile fullgraph"), + "cuda_compile_precompile": OptionInfo(False, "Model compile precompile"), + "cuda_compile_verbose": OptionInfo(False, "Model compile verbose mode"), + "cuda_compile_errors": OptionInfo(True, "Model compile suppress errors"), + "deep_cache_interval": OptionInfo(3, "DeepCache cache interval", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), })) options_templates.update(options_section(('system-paths', "System Paths"), { - "models_paths_sep_options": OptionInfo("

<h2>Models paths</h2>
", "", gr.HTML), + "models_paths_sep_options": OptionInfo("

<h2>Models Paths</h2>
", "", gr.HTML), "models_dir": OptionInfo('models', "Base path where all models are stored", folder=True), "ckpt_dir": OptionInfo(os.path.join(paths.models_path, 'Stable-diffusion'), "Folder with stable diffusion models", folder=True), "diffusers_dir": OptionInfo(os.path.join(paths.models_path, 'Diffusers'), "Folder with Huggingface models", folder=True), @@ -723,13 +738,13 @@ def get_default_modes(): })) options_templates.update(options_section(('saving-paths', "Image Paths"), { - "saving_sep_images": OptionInfo("

<h2>Save options</h2>
", "", gr.HTML), + "saving_sep_images": OptionInfo("

<h2>Save Options</h2>
", "", gr.HTML), "save_images_add_number": OptionInfo(True, "Numbered filenames", component_args=hide_dirs), "use_original_name_batch": OptionInfo(True, "Batch uses original name"), "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"), "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs), "samples_filename_pattern": OptionInfo("[seq]-[model_name]-[prompt_words]", "Images filename pattern", component_args=hide_dirs), - "directories_max_prompt_words": OptionInfo(8, "Max words per pattern", gr.Slider, {"minimum": 1, "maximum": 99, "step": 1, **hide_dirs}), + "directories_max_prompt_words": OptionInfo(8, "Max words", gr.Slider, {"minimum": 1, "maximum": 99, "step": 1, **hide_dirs}), "outdir_sep_dirs": OptionInfo("

<h2>Folders</h2>
", "", gr.HTML), "outdir_samples": OptionInfo("", "Images folder", component_args=hide_dirs, folder=True), @@ -748,14 +763,14 @@ def get_default_modes(): "outdir_control_grids": OptionInfo("outputs/grids", 'Folder for control grids', component_args=hide_dirs, folder=True), })) -options_templates.update(options_section(('ui', "User Interface Options"), { +options_templates.update(options_section(('ui', "User Interface"), { "theme_type": OptionInfo("Standard", "Theme type", gr.Radio, {"choices": ["Modern", "Standard", "None"]}), "theme_style": OptionInfo("Auto", "Theme mode", gr.Radio, {"choices": ["Auto", "Dark", "Light"]}), "gradio_theme": OptionInfo("black-teal", "UI theme", gr.Dropdown, lambda: {"choices": theme.list_themes()}, refresh=theme.refresh_themes), "autolaunch": OptionInfo(False, "Autolaunch browser upon startup"), "font_size": OptionInfo(14, "Font size", gr.Slider, {"minimum": 8, "maximum": 32, "step": 1, "visible": True}), "aspect_ratios": OptionInfo("1:1, 4:3, 3:2, 16:9, 16:10, 21:9, 2:3, 3:4, 9:16, 10:16, 9:21", "Allowed aspect ratios"), - "motd": OptionInfo(True, "Show MOTD"), + "motd": OptionInfo(False, "Show MOTD"), "compact_view": OptionInfo(False, "Compact view"), "return_grid": OptionInfo(True, "Show grid in results"), "return_mask": OptionInfo(False, "Inpainting include greyscale mask in results"), @@ -767,14 +782,15 @@ def get_default_modes(): })) options_templates.update(options_section(('live-preview', "Live Previews"), { - "notification_audio_enable": OptionInfo(False, "Play a notification upon completion"), - "notification_audio_path": OptionInfo("html/notification.mp3","Path to notification sound", component_args=hide_dirs, folder=True), "show_progress_every_n_steps": OptionInfo(1, "Live preview display period", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), "show_progress_type": OptionInfo("Approximate", "Live preview method", gr.Radio, {"choices": ["Simple", "Approximate", "TAESD", "Full VAE"]}), "live_preview_refresh_period": OptionInfo(500, "Progress update period", gr.Slider, {"minimum": 0, "maximum": 5000, "step": 25}), "live_preview_taesd_layers": OptionInfo(3, "TAESD decode layers", gr.Slider, {"minimum": 1, "maximum": 3, "step": 1}), + "live_preview_downscale": OptionInfo(True, "Downscale high resolution live previews"), "logmonitor_show": OptionInfo(True, "Show log view"), "logmonitor_refresh_period": OptionInfo(5000, "Log view update period", gr.Slider, {"minimum": 0, "maximum": 30000, "step": 25}), + "notification_audio_enable": OptionInfo(False, "Play a notification upon completion"), + "notification_audio_path": OptionInfo("html/notification.mp3","Path to notification sound", component_args=hide_dirs, folder=True), })) options_templates.update(options_section(('sampler-params', "Sampler Settings"), { @@ -813,7 +829,7 @@ def get_default_modes(): 's_noise': OptionInfo(1.0, "Sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01, "visible": not native}), 's_min': OptionInfo(0.0, "Sigma min", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01, "visible": not native}), 's_max': OptionInfo(0.0, "Sigma max", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 1.0, "visible": not native}), - "schedulers_sep_compvis": OptionInfo("

<h2>CompVis specific config</h2>
", "", gr.HTML, {"visible": not native}), + "schedulers_sep_compvis": OptionInfo("

<h2>CompVis Config</h2>
", "", gr.HTML, {"visible": not native}), 'uni_pc_variant': OptionInfo("bh2", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"], "visible": not native}), 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"], "visible": not native}), "ddim_discretize": OptionInfo('uniform', "DDIM discretize img2img", gr.Radio, {"choices": ['uniform', 'quad'], "visible": not native}), @@ -821,9 +837,9 @@ def get_default_modes(): options_templates.update(options_section(('postprocessing', "Postprocessing"), { 'postprocessing_enable_in_main_ui': OptionInfo([], "Additional postprocessing operations", gr.Dropdown, lambda: {"multiselect":True, "choices": [x.name for x in shared_items.postprocessing_scripts()]}), - 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", gr.Dropdown, lambda: {"multiselect":True, "choices": [x.name for x in shared_items.postprocessing_scripts()]}), + 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", gr.Dropdown, lambda: {"multiselect":True, "choices": [x.name for x in shared_items.postprocessing_scripts()], "visible": False }), - "postprocessing_sep_img2img": OptionInfo("

<h2>Img2Img & Inpainting</h2>
", "", gr.HTML), + "postprocessing_sep_img2img": OptionInfo("

<h2>Inpaint</h2>
", "", gr.HTML), "img2img_color_correction": OptionInfo(False, "Apply color correction"), "mask_apply_overlay": OptionInfo(True, "Apply mask as overlay"), "img2img_background_color": OptionInfo("#ffffff", "Image transparent color fill", gr.ColorPicker, {}), @@ -831,7 +847,7 @@ def get_default_modes(): "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for image processing", gr.Slider, {"minimum": 0.1, "maximum": 1.5, "step": 0.01, "visible": not native}), "img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01, "visible": not native}), - # "postprocessing_sep_detailer": OptionInfo("

<h2>Detailer</h2>
", "", gr.HTML), + "postprocessing_sep_detailer": OptionInfo("

<h2>Detailer</h2>
", "", gr.HTML), "detailer_model": OptionInfo("Detailer", "Detailer model", gr.Radio, lambda: {"choices": [x.name() for x in detailers], "visible": False}), "detailer_classes": OptionInfo("", "Detailer classes", gr.Textbox, { "visible": False}), "detailer_conf": OptionInfo(0.6, "Min confidence", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.05, "visible": False}), @@ -843,11 +859,12 @@ def get_default_modes(): "detailer_blur": OptionInfo(10, "Item edge blur", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1, "visible": False}), "detailer_strength": OptionInfo(0.5, "Detailer strength", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": False}), "detailer_models": OptionInfo(['face-yolo8n'], "Detailer models", gr.Dropdown, lambda: {"multiselect":True, "choices": list(yolo.list), "visible": False}), - "code_former_weight": OptionInfo(0.2, "CodeFormer weight parameter", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": False}), "detailer_unload": OptionInfo(False, "Move detailer model to CPU when complete"), + "detailer_augment": OptionInfo(True, "Detailer use model augment"), - "postprocessing_sep_face_restore": OptionInfo("

<h2>Face restore</h2>
", "", gr.HTML), - "face_restoration_model": OptionInfo("Face restorer", "Face restoration", gr.Radio, lambda: {"choices": ['None'] + [x.name() for x in face_restorers]}), + "postprocessing_sep_face_restore": OptionInfo("

<h2>Face Restore</h2>
", "", gr.HTML), + "face_restoration_model": OptionInfo("None", "Face restoration", gr.Radio, lambda: {"choices": ['None'] + [x.name() for x in face_restorers]}), + "code_former_weight": OptionInfo(0.2, "CodeFormer weight parameter", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "postprocessing_sep_upscalers": OptionInfo("

<h2>Upscaling</h2>
", "", gr.HTML), "upscaler_unload": OptionInfo(False, "Unload upscaler after processing"), @@ -857,6 +874,7 @@ def get_default_modes(): options_templates.update(options_section(('control', "Control Options"), { "control_max_units": OptionInfo(4, "Maximum number of units", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), + "control_tiles": OptionInfo("1x1, 1x2, 1x3, 1x4, 2x1, 2x1, 2x2, 2x3, 2x4, 3x1, 3x2, 3x3, 3x4, 4x1, 4x2, 4x3, 4x4", "Tiling options"), "control_move_processor": OptionInfo(False, "Processor move to CPU after use"), "control_unload_processor": OptionInfo(False, "Processor unload after use"), })) @@ -875,6 +893,15 @@ def get_default_modes(): "deepbooru_filter_tags": OptionInfo("", "Filter out tags from deepbooru output"), })) +options_templates.update(options_section(('huggingface', "Huggingface"), { + "huggingface_sep": OptionInfo("

<h2>Huggingface</h2>
", "", gr.HTML), + "diffuser_cache_config": OptionInfo(True, "Use cached model config when available"), + "huggingface_token": OptionInfo('', 'HuggingFace token'), + "diffusers_model_load_variant": OptionInfo("default", "Preferred Model variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}), + "diffusers_vae_load_variant": OptionInfo("default", "Preferred VAE variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}), + "custom_diffusers_pipeline": OptionInfo('', 'Load custom Diffusers pipeline'), +})) + options_templates.update(options_section(('extra_networks', "Networks"), { "extra_networks_sep1": OptionInfo("

<h2>Networks UI</h2>
", "", gr.HTML), "extra_networks_show": OptionInfo(True, "UI show on startup"), @@ -891,23 +918,27 @@ def get_default_modes(): "extra_networks_model_sep": OptionInfo("

<h2>Models</h2>
", "", gr.HTML), "extra_network_reference": OptionInfo(False, "Use reference values when available", gr.Checkbox), - "extra_networks_embed_sep": OptionInfo("
Embeddings
", "", gr.HTML), - "diffusers_convert_embed": OptionInfo(False, "Auto-convert SD 1.5 embeddings to SDXL ", gr.Checkbox, {"visible": native}), - "extra_networks_styles_sep": OptionInfo("
Styles
", "", gr.HTML), - "extra_networks_styles": OptionInfo(True, "Show built-in styles"), - "extra_networks_wildcard_sep": OptionInfo("
Wildcards
", "", gr.HTML), - "wildcards_enabled": OptionInfo(True, "Enable file wildcards support"), + "extra_networks_lora_sep": OptionInfo("
LoRA
", "", gr.HTML), "extra_networks_default_multiplier": OptionInfo(1.0, "Default strength", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}), - "lora_preferred_name": OptionInfo("filename", "LoRA preferred name", gr.Radio, {"choices": ["filename", "alias"]}), - "lora_add_hashes_to_infotext": OptionInfo(False, "LoRA add hash info"), + "lora_preferred_name": OptionInfo("filename", "LoRA preferred name", gr.Radio, {"choices": ["filename", "alias"], "visible": False}), + "lora_add_hashes_to_infotext": OptionInfo(False, "LoRA add hash info to metadata"), + "lora_fuse_diffusers": OptionInfo(True, "LoRA fuse directly to model"), "lora_force_diffusers": OptionInfo(False if not cmd_opts.use_openvino else True, "LoRA force loading of all models using Diffusers"), "lora_maybe_diffusers": OptionInfo(False, "LoRA force loading of specific models using Diffusers"), - "lora_fuse_diffusers": OptionInfo(False if not cmd_opts.use_openvino else True, "LoRA use fuse when possible"), "lora_apply_tags": OptionInfo(0, "LoRA auto-apply tags", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), "lora_in_memory_limit": OptionInfo(0, "LoRA memory cache", gr.Slider, {"minimum": 0, "maximum": 24, "step": 1}), - "lora_quant": OptionInfo("NF4","LoRA precision in quantized models", gr.Radio, {"choices": ["NF4", "FP4"]}), - "lora_load_gpu": OptionInfo(True if not cmd_opts.lowvram else False, "Load LoRA directly to GPU"), + "lora_quant": OptionInfo("NF4","LoRA precision when quantized", gr.Radio, {"choices": ["NF4", "FP4"]}), + + "extra_networks_styles_sep": OptionInfo("
Styles
", "", gr.HTML), + "extra_networks_styles": OptionInfo(True, "Show built-in styles"), + + "extra_networks_embed_sep": OptionInfo("
Embeddings
", "", gr.HTML), + "diffusers_enable_embed": OptionInfo(True, "Enable embeddings support", gr.Checkbox, {"visible": native}), + "diffusers_convert_embed": OptionInfo(False, "Auto-convert SD15 embeddings to SDXL", gr.Checkbox, {"visible": native}), + + "extra_networks_wildcard_sep": OptionInfo("
Wildcards
", "", gr.HTML), + "wildcards_enabled": OptionInfo(True, "Enable file wildcards support"), })) options_templates.update(options_section((None, "Internal options"), { diff --git a/modules/shared_items.py b/modules/shared_items.py index 4d9b325c8..9abb64718 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -86,6 +86,14 @@ def get_pipelines(): 'Kolors': getattr(diffusers, 'KolorsPipeline', None), 'AuraFlow': getattr(diffusers, 'AuraFlowPipeline', None), 'CogView': getattr(diffusers, 'CogView3PlusPipeline', None), + 'Stable Cascade': getattr(diffusers, 'StableCascadeCombinedPipeline', None), + 'PixArt-Sigma': getattr(diffusers, 'PixArtSigmaPipeline', None), + 'HunyuanDiT': getattr(diffusers, 'HunyuanDiTPipeline', None), + 'Stable Diffusion 3': getattr(diffusers, 'StableDiffusion3Pipeline', None), + 'Stable Diffusion 3 Img2Img': getattr(diffusers, 'StableDiffusion3Img2ImgPipeline', None), + 'Lumina-Next': getattr(diffusers, 'LuminaText2ImgPipeline', None), + 'FLUX': getattr(diffusers, 'FluxPipeline', None), + 'Sana': getattr(diffusers, 'SanaPAGPipeline', None), } if hasattr(diffusers, 'OnnxStableDiffusionPipeline'): onnx_pipelines = { @@ -103,19 +111,10 @@ def get_pipelines(): pipelines.update(onnx_pipelines) # items that may rely on diffusers dev version - if hasattr(diffusers, 'StableCascadeCombinedPipeline'): - pipelines['Stable Cascade'] = getattr(diffusers, 'StableCascadeCombinedPipeline', None) - if hasattr(diffusers, 'PixArtSigmaPipeline'): - pipelines['PixArt-Sigma'] = getattr(diffusers, 'PixArtSigmaPipeline', None) - if hasattr(diffusers, 'HunyuanDiTPipeline'): - pipelines['HunyuanDiT'] = getattr(diffusers, 'HunyuanDiTPipeline', None) - if hasattr(diffusers, 'StableDiffusion3Pipeline'): - pipelines['Stable Diffusion 3'] = getattr(diffusers, 'StableDiffusion3Pipeline', None) - pipelines['Stable Diffusion 3 Img2Img'] = getattr(diffusers, 'StableDiffusion3Img2ImgPipeline', None) - if hasattr(diffusers, 'LuminaText2ImgPipeline'): - pipelines['Lumina-Next'] = getattr(diffusers, 'LuminaText2ImgPipeline', None) + """ if hasattr(diffusers, 'FluxPipeline'): pipelines['FLUX'] = getattr(diffusers, 'FluxPipeline', None) + """ for k, v in pipelines.items(): if k != 'Autodetect' and v is None: diff --git a/modules/shared_state.py b/modules/shared_state.py index 7def42b8c..a3312ec33 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -28,6 +28,9 @@ class State: oom = False debug_output = os.environ.get('SD_STATE_DEBUG', None) + def __str__(self) -> str: + return f'State: job={self.job} {self.job_no}/{self.job_count} step={self.sampling_step}/{self.sampling_steps} skipped={self.skipped} interrupted={self.interrupted} paused={self.paused} info={self.textinfo}' + def skip(self): log.debug('Requested skip') self.skipped = True @@ -135,11 +138,12 @@ def end(self, api=None): modules.devices.torch_gc() def set_current_image(self): + if self.job == 'VAE': # avoid generating preview while vae is running + return from modules.shared import opts, cmd_opts - """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" - if cmd_opts.lowvram or self.api: + if cmd_opts.lowvram or self.api or not opts.live_previews_enable or opts.show_progress_every_n_steps <= 0: return - if abs(self.sampling_step - self.current_image_sampling_step) >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps > 0: + if abs(self.sampling_step - self.current_image_sampling_step) >= 
opts.show_progress_every_n_steps: self.do_set_current_image() def do_set_current_image(self): diff --git a/modules/style_aligned/inversion.py b/modules/style_aligned/inversion.py new file mode 100644 index 000000000..8c91cc02a --- /dev/null +++ b/modules/style_aligned/inversion.py @@ -0,0 +1,124 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations +from typing import Callable, TYPE_CHECKING +from diffusers import StableDiffusionXLPipeline +import torch +from tqdm import tqdm +if TYPE_CHECKING: + import numpy as np + + +T = torch.Tensor +TN = T +InversionCallback = Callable[[StableDiffusionXLPipeline, int, T, dict[str, T]], dict[str, T]] + + +def _get_text_embeddings(prompt: str, tokenizer, text_encoder, device): + # Tokenize text and get embeddings + text_inputs = tokenizer(prompt, padding='max_length', max_length=tokenizer.model_max_length, truncation=True, return_tensors='pt') + text_input_ids = text_inputs.input_ids + + with torch.no_grad(): + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + if prompt == '': + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + return negative_prompt_embeds, negative_pooled_prompt_embeds + return prompt_embeds, pooled_prompt_embeds + + +def _encode_text_sdxl(model: StableDiffusionXLPipeline, prompt: str) -> tuple[dict[str, T], T]: + device = model._execution_device # pylint: disable=protected-access + prompt_embeds, pooled_prompt_embeds, = _get_text_embeddings(prompt, model.tokenizer, model.text_encoder, device) # pylint: disable=unused-variable + prompt_embeds_2, pooled_prompt_embeds2, = _get_text_embeddings( prompt, model.tokenizer_2, model.text_encoder_2, device) + prompt_embeds = torch.cat((prompt_embeds, prompt_embeds_2), dim=-1) + text_encoder_projection_dim = model.text_encoder_2.config.projection_dim + add_time_ids = model._get_add_time_ids((1024, 1024), (0, 0), (1024, 1024), model.text_encoder.dtype, # pylint: disable=protected-access + text_encoder_projection_dim).to(device) + added_cond_kwargs = {"text_embeds": pooled_prompt_embeds2, "time_ids": add_time_ids} + return added_cond_kwargs, prompt_embeds + + +def _encode_text_sdxl_with_negative(model: StableDiffusionXLPipeline, prompt: str) -> tuple[dict[str, T], T]: + added_cond_kwargs, prompt_embeds = _encode_text_sdxl(model, prompt) + added_cond_kwargs_uncond, prompt_embeds_uncond = _encode_text_sdxl(model, "") + prompt_embeds = torch.cat((prompt_embeds_uncond, prompt_embeds, )) + added_cond_kwargs = {"text_embeds": torch.cat((added_cond_kwargs_uncond["text_embeds"], added_cond_kwargs["text_embeds"])), + "time_ids": torch.cat((added_cond_kwargs_uncond["time_ids"], added_cond_kwargs["time_ids"])),} + return added_cond_kwargs, prompt_embeds + + +def _encode_image(model: StableDiffusionXLPipeline, image: np.ndarray) -> 
T: + image = torch.from_numpy(image).float() / 255. + image = (image * 2 - 1).permute(2, 0, 1).unsqueeze(0) + latent = model.vae.encode(image.to(model.vae.device, model.vae.dtype))['latent_dist'].mean * model.vae.config.scaling_factor + return latent + + +def _next_step(model: StableDiffusionXLPipeline, model_output: T, timestep: int, sample: T) -> T: + timestep, next_timestep = min(timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps, 999), timestep + alpha_prod_t = model.scheduler.alphas_cumprod[int(timestep)] if timestep >= 0 else model.scheduler.final_alpha_cumprod + alpha_prod_t_next = model.scheduler.alphas_cumprod[int(next_timestep)] + beta_prod_t = 1 - alpha_prod_t + next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 + next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output + next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction + return next_sample + + +def _get_noise_pred(model: StableDiffusionXLPipeline, latent: T, t: T, context: T, guidance_scale: float, added_cond_kwargs: dict[str, T]): + latents_input = torch.cat([latent] * 2) + noise_pred = model.unet(latents_input, t, encoder_hidden_states=context, added_cond_kwargs=added_cond_kwargs)["sample"] + noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + # latents = next_step(model, noise_pred, t, latent) + return noise_pred + + +def _ddim_loop(model: StableDiffusionXLPipeline, z0, prompt, guidance_scale) -> T: + all_latent = [z0] + added_cond_kwargs, text_embedding = _encode_text_sdxl_with_negative(model, prompt) + latent = z0.clone().detach().to(model.text_encoder.dtype) + for i in tqdm(range(model.scheduler.num_inference_steps)): + t = model.scheduler.timesteps[len(model.scheduler.timesteps) - i - 1] + noise_pred = _get_noise_pred(model, latent, t, text_embedding, guidance_scale, added_cond_kwargs) + latent = _next_step(model, noise_pred, t, latent) + all_latent.append(latent) + return torch.cat(all_latent).flip(0) + + +def make_inversion_callback(zts, offset: int = 0): + + def callback_on_step_end(pipeline: StableDiffusionXLPipeline, i: int, t: T, callback_kwargs: dict[str, T]) -> dict[str, T]: # pylint: disable=unused-argument + latents = callback_kwargs['latents'] + latents[0] = zts[max(offset + 1, i + 1)].to(latents.device, latents.dtype) + return {'latents': latents} + return zts[offset], callback_on_step_end + + +@torch.no_grad() +def ddim_inversion(model: StableDiffusionXLPipeline, x0: np.ndarray, prompt: str, num_inference_steps: int, guidance_scale,) -> T: + z0 = _encode_image(model, x0) + model.scheduler.set_timesteps(num_inference_steps, device=z0.device) + zs = _ddim_loop(model, z0, prompt, guidance_scale) + return zs diff --git a/modules/style_aligned/sa_handler.py b/modules/style_aligned/sa_handler.py new file mode 100644 index 000000000..ee4b1ca79 --- /dev/null +++ b/modules/style_aligned/sa_handler.py @@ -0,0 +1,281 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from diffusers import StableDiffusionXLPipeline +from dataclasses import dataclass +import torch +import torch.nn as nn +from torch.nn import functional as nnf +from diffusers.models import attention_processor # pylint: disable=ungrouped-imports +import einops + +T = torch.Tensor + + +@dataclass(frozen=True) +class StyleAlignedArgs: + share_group_norm: bool = True + share_layer_norm: bool = True + share_attention: bool = True + adain_queries: bool = True + adain_keys: bool = True + adain_values: bool = False + full_attention_share: bool = False + shared_score_scale: float = 1. + shared_score_shift: float = 0. + only_self_level: float = 0. + + +def expand_first(feat: T, scale=1.,) -> T: + b = feat.shape[0] + feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1) + if scale == 1: + feat_style = feat_style.expand(2, b // 2, *feat.shape[1:]) + else: + feat_style = feat_style.repeat(1, b // 2, 1, 1, 1) + feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1) + return feat_style.reshape(*feat.shape) + + +def concat_first(feat: T, dim=2, scale=1.) -> T: + feat_style = expand_first(feat, scale=scale) + return torch.cat((feat, feat_style), dim=dim) + + +def calc_mean_std(feat, eps: float = 1e-5) -> tuple[T, T]: + feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt() + feat_mean = feat.mean(dim=-2, keepdims=True) + return feat_mean, feat_std + + +def adain(feat: T) -> T: + feat_mean, feat_std = calc_mean_std(feat) + feat_style_mean = expand_first(feat_mean) + feat_style_std = expand_first(feat_std) + feat = (feat - feat_mean) / feat_std + feat = feat * feat_style_std + feat_style_mean + return feat + + +class DefaultAttentionProcessor(nn.Module): + + def __init__(self): + super().__init__() + self.processor = attention_processor.AttnProcessor2_0() + + def __call__(self, attn: attention_processor.Attention, hidden_states, encoder_hidden_states=None, + attention_mask=None, **kwargs): + return self.processor(attn, hidden_states, encoder_hidden_states, attention_mask) + + +class SharedAttentionProcessor(DefaultAttentionProcessor): + + def shifted_scaled_dot_product_attention(self, attn: attention_processor.Attention, query: T, key: T, value: T) -> T: + logits = torch.einsum('bhqd,bhkd->bhqk', query, key) * attn.scale + logits[:, :, :, query.shape[2]:] += self.shared_score_shift + probs = logits.softmax(-1) + return torch.einsum('bhqk,bhkd->bhqd', probs, value) + + def shared_call( # pylint: disable=unused-argument + self, + attn: attention_processor.Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + **kwargs + ): + + residual = hidden_states + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = 
attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + # if self.step >= self.start_inject: + if self.adain_queries: + query = adain(query) + if self.adain_keys: + key = adain(key) + if self.adain_values: + value = adain(value) + if self.share_attention: + key = concat_first(key, -2, scale=self.shared_score_scale) + value = concat_first(value, -2) + if self.shared_score_shift != 0: + hidden_states = self.shifted_scaled_dot_product_attention(attn, query, key, value,) + else: + hidden_states = nnf.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + else: + hidden_states = nnf.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + # hidden_states = adain(hidden_states) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + + def __call__(self, attn: attention_processor.Attention, hidden_states, encoder_hidden_states=None, + attention_mask=None, **kwargs): + if self.full_attention_share: + _b, n, _d = hidden_states.shape + hidden_states = einops.rearrange(hidden_states, '(k b) n d -> k (b n) d', k=2) + hidden_states = super().__call__(attn, hidden_states, encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, **kwargs) + hidden_states = einops.rearrange(hidden_states, 'k (b n) d -> (k b) n d', n=n) + else: + hidden_states = self.shared_call(attn, hidden_states, hidden_states, attention_mask, **kwargs) + + return hidden_states + + def __init__(self, style_aligned_args: StyleAlignedArgs): + super().__init__() + self.share_attention = style_aligned_args.share_attention + self.adain_queries = style_aligned_args.adain_queries + self.adain_keys = style_aligned_args.adain_keys + self.adain_values = style_aligned_args.adain_values + self.full_attention_share = style_aligned_args.full_attention_share + self.shared_score_scale = style_aligned_args.shared_score_scale + self.shared_score_shift = style_aligned_args.shared_score_shift + + +def _get_switch_vec(total_num_layers, level): + if level <= 0: + return torch.zeros(total_num_layers, dtype=torch.bool) + if level >= 1: + return torch.ones(total_num_layers, dtype=torch.bool) + to_flip = level > .5 + if to_flip: + level = 1 - level + num_switch = int(level * total_num_layers) + vec = 
torch.arange(total_num_layers) + vec = vec % (total_num_layers // num_switch) + vec = vec == 0 + if to_flip: + vec = ~vec + return vec + + +def init_attention_processors(pipeline: StableDiffusionXLPipeline, style_aligned_args: StyleAlignedArgs | None = None): + attn_procs = {} + unet = pipeline.unet + number_of_self, number_of_cross = 0, 0 + num_self_layers = len([name for name in unet.attn_processors.keys() if 'attn1' in name]) + if style_aligned_args is None: + only_self_vec = _get_switch_vec(num_self_layers, 1) + else: + only_self_vec = _get_switch_vec(num_self_layers, style_aligned_args.only_self_level) + for i, name in enumerate(unet.attn_processors.keys()): + is_self_attention = 'attn1' in name + if is_self_attention: + number_of_self += 1 + if style_aligned_args is None or only_self_vec[i // 2]: + attn_procs[name] = DefaultAttentionProcessor() + else: + attn_procs[name] = SharedAttentionProcessor(style_aligned_args) + else: + number_of_cross += 1 + attn_procs[name] = DefaultAttentionProcessor() + + unet.set_attn_processor(attn_procs) + + +def register_shared_norm(pipeline: StableDiffusionXLPipeline, + share_group_norm: bool = True, + share_layer_norm: bool = True, + ): + def register_norm_forward(norm_layer: nn.GroupNorm | nn.LayerNorm) -> nn.GroupNorm | nn.LayerNorm: + if not hasattr(norm_layer, 'orig_forward'): + setattr(norm_layer, 'orig_forward', norm_layer.forward) # noqa + orig_forward = norm_layer.orig_forward + + def forward_(hidden_states: T) -> T: + n = hidden_states.shape[-2] + hidden_states = concat_first(hidden_states, dim=-2) + hidden_states = orig_forward(hidden_states) + return hidden_states[..., :n, :] + + norm_layer.forward = forward_ + return norm_layer + + def get_norm_layers(pipeline_, norm_layers_: dict[str, list[nn.GroupNorm | nn.LayerNorm]]): + if isinstance(pipeline_, nn.LayerNorm) and share_layer_norm: + norm_layers_['layer'].append(pipeline_) + if isinstance(pipeline_, nn.GroupNorm) and share_group_norm: + norm_layers_['group'].append(pipeline_) + else: + for layer in pipeline_.children(): + get_norm_layers(layer, norm_layers_) + + norm_layers = {'group': [], 'layer': []} + get_norm_layers(pipeline.unet, norm_layers) + return [register_norm_forward(layer) for layer in norm_layers['group']] + [register_norm_forward(layer) for layer in + norm_layers['layer']] + + +class Handler: + + def register(self, style_aligned_args: StyleAlignedArgs): + self.norm_layers = register_shared_norm(self.pipeline, style_aligned_args.share_group_norm, + style_aligned_args.share_layer_norm) + init_attention_processors(self.pipeline, style_aligned_args) + + def remove(self): + for layer in self.norm_layers: + layer.forward = layer.orig_forward + self.norm_layers = [] + init_attention_processors(self.pipeline, None) + + def __init__(self, pipeline: StableDiffusionXLPipeline): + self.pipeline = pipeline + self.norm_layers = [] diff --git a/modules/styles.py b/modules/styles.py index 0dd48eb7f..d0228d33a 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -52,7 +52,7 @@ def check_files(prompt, wildcard, files): choice = random.choice(lines).strip(' \n') if '|' in choice: choice = random.choice(choice.split('|')).strip(' []{}\n') - prompt = prompt.replace(f"__{wildcard}__", choice) + prompt = prompt.replace(f"__{wildcard}__", choice, 1) shared.log.debug(f'Wildcards apply: wildcard="{wildcard}" choice="{choice}" file="{file}" choices={len(lines)}') replaced.append(wildcard) return prompt, True diff --git a/modules/textual_inversion/textual_inversion.py 
b/modules/textual_inversion/textual_inversion.py index dd4203a4f..de12021a5 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -4,13 +4,14 @@ import torch import safetensors.torch from PIL import Image -from modules import shared, devices, sd_models, errors +from modules import shared, devices, errors from modules.textual_inversion.image_embedding import embedding_from_b64, extract_image_data_embed from modules.files_cache import directory_files, directory_mtime, extension_filter debug = shared.log.trace if os.environ.get('SD_TI_DEBUG', None) is not None else lambda *args, **kwargs: None debug('Trace: TEXTUAL INVERSION') +supported_models = ['ldm', 'sd', 'sdxl'] def list_embeddings(*dirs): @@ -274,6 +275,8 @@ def load_diffusers_embedding(self, filename: Union[str, List[str]] = None, data: overwrite = bool(data) if not shared.sd_loaded: return + if not shared.opts.diffusers_enable_embed: + return embeddings, skipped = open_embeddings(filename) or convert_bundled(data) for skip in skipped: self.skipped_embeddings[skip.name] = skipped @@ -368,7 +371,7 @@ def load_from_file(self, path, filename): self.skipped_embeddings[name] = embedding def load_from_dir(self, embdir): - if sd_models.model_data.sd_model is None: + if not shared.sd_loaded: shared.log.info('Skipping embeddings load: model not loaded') return if not os.path.isdir(embdir.path): @@ -388,6 +391,8 @@ def load_from_dir(self, embdir): def load_textual_inversion_embeddings(self, force_reload=False): if not shared.sd_loaded: return + if shared.sd_model_type not in supported_models: + return t0 = time.time() if not force_reload: need_reload = False diff --git a/modules/timer.py b/modules/timer.py index 8a5db726d..43e859140 100644 --- a/modules/timer.py +++ b/modules/timer.py @@ -7,6 +7,7 @@ def __init__(self): self.start = time.time() self.records = {} self.total = 0 + self.profile = False def elapsed(self, reset=True): end = time.time() @@ -15,17 +16,24 @@ def elapsed(self, reset=True): self.start = end return res + def add(self, name, t): + if name not in self.records: + self.records[name] = t + else: + self.records[name] += t + def record(self, category=None, extra_time=0, reset=True): e = self.elapsed(reset) if category is None: category = sys._getframe(1).f_code.co_name # pylint: disable=protected-access if category not in self.records: self.records[category] = 0 - self.records[category] += e + extra_time self.total += e + extra_time def summary(self, min_time=0.05, total=True): + if self.profile: + min_time = -1 res = f"{self.total:.2f} " if total else '' additions = [x for x in self.records.items() if x[1] >= min_time] if not additions: @@ -34,6 +42,8 @@ def summary(self, min_time=0.05, total=True): return res def dct(self, min_time=0.05): + if self.profile: + return {k: round(v, 4) for k, v in self.records.items()} return {k: round(v, 2) for k, v in self.records.items() if v >= min_time} def reset(self): diff --git a/modules/txt2img.py b/modules/txt2img.py index 2f0e2f4b3..e82c744a2 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -88,7 +88,7 @@ def txt2img(id_task, state, p.scripts = scripts.scripts_txt2img p.script_args = args p.state = state - processed = scripts.scripts_txt2img.run(p, *args) + processed: processing.Processed = scripts.scripts_txt2img.run(p, *args) if processed is None: processed = processing.process_images(p) processed = scripts.scripts_txt2img.after(p, processed, *args) diff --git a/modules/ui_common.py b/modules/ui_common.py index 
9c4bb5cdc..3e7c68bec 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -245,10 +245,18 @@ def create_output_panel(tabname, preview=True, prompt=None, height=None): gr.HTML(value="", elem_id="main_info", visible=False, elem_classes=["main-info"]) # columns are for <576px, <768px, <992px, <1200px, <1400px, >1400px result_gallery = gr.Gallery(value=[], - label='Output', show_label=False, show_download_button=True, allow_preview=True, container=False, preview=preview, - columns=4, object_fit='scale-down', height=height, + label='Output', + show_label=False, + show_download_button=True, + allow_preview=True, + container=False, + preview=preview, + columns=4, + object_fit='scale-down', + height=height, elem_id=f"{tabname}_gallery", - ) + elem_classes=["gallery_main"], + ) if prompt is not None: interrogate_clip_btn, interrogate_booru_btn = ui_sections.create_interrogate_buttons('control') interrogate_clip_btn.click(fn=interrogate_clip, inputs=[result_gallery], outputs=[prompt]) diff --git a/modules/ui_control.py b/modules/ui_control.py index 0bf070036..cd78313a7 100644 --- a/modules/ui_control.py +++ b/modules/ui_control.py @@ -9,11 +9,11 @@ from modules.control.units import lite # vislearn ControlNet-XS from modules.control.units import t2iadapter # TencentARC T2I-Adapter from modules.control.units import reference # reference pipeline -from modules import errors, shared, progress, ui_components, ui_symbols, ui_common, ui_sections, generation_parameters_copypaste, call_queue, scripts, masking, images, processing_vae # pylint: disable=ungrouped-imports +from modules import errors, shared, progress, ui_components, ui_symbols, ui_common, ui_sections, generation_parameters_copypaste, call_queue, scripts, masking, images, processing_vae, timer # pylint: disable=ungrouped-imports from modules import ui_control_helpers as helpers -gr_height = None +gr_height = 512 max_units = shared.opts.control_max_units units: list[unit.Unit] = [] # main state variable controls: list[gr.component] = [] # list of gr controls @@ -21,13 +21,41 @@ debug('Trace: CONTROL') -def return_controls(res): +def return_stats(t: float = None): + if t is None: + elapsed_text = '' + else: + elapsed = time.perf_counter() - t + elapsed_m = int(elapsed // 60) + elapsed_s = elapsed % 60 + elapsed_text = f"Time: {elapsed_m}m {elapsed_s:.2f}s |" if elapsed_m > 0 else f"Time: {elapsed_s:.2f}s |" + summary = timer.process.summary(min_time=0.25, total=False).replace('=', ' ') + gpu = '' + cpu = '' + if not shared.mem_mon.disabled: + vram = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.read().items()} + peak = max(vram['active_peak'], vram['reserved_peak'], vram['used']) + used = round(100.0 * peak / vram['total']) if vram['total'] > 0 else 0 + if used > 0: + gpu += f"| GPU {peak} MB {used}%" + gpu += f" | retries {vram['retries']} oom {vram['oom']}" if vram.get('retries', 0) > 0 or vram.get('oom', 0) > 0 else '' + ram = shared.ram_stats() + if ram['used'] > 0: + cpu += f"| RAM {ram['used']} GB {round(100.0 * ram['used'] / ram['total'])}%" + return f"
Time: {elapsed_text} | {summary} {gpu} {cpu}
" + + +def return_controls(res, t: float = None): # return preview, image, video, gallery, text debug(f'Control received: type={type(res)} {res}') + if t is None: + perf = '' + else: + perf = return_stats(t) if res is None: # no response - return [None, None, None, None, ''] + return [None, None, None, None, '', perf] elif isinstance(res, str): # error response - return [None, None, None, None, res] + return [None, None, None, None, res, perf] elif isinstance(res, tuple): # standard response received as tuple via control_run->yield(output_images, process_image, result_txt) preview_image = res[1] # may be None output_image = res[0][0] if isinstance(res[0], list) else res[0] # may be image or list of images @@ -37,9 +65,9 @@ def return_controls(res): output_gallery = [res[0]] if res[0] is not None else [] # must return list, but can receive single image result_txt = res[2] if len(res) > 2 else '' # do we have a message output_video = res[3] if len(res) > 3 else None # do we have a video filename - return [preview_image, output_image, output_video, output_gallery, result_txt] + return [preview_image, output_image, output_video, output_gallery, result_txt, perf] else: # unexpected - return [None, None, None, None, f'Control: Unexpected response: {type(res)}'] + return [None, None, None, None, f'Control: Unexpected response: {type(res)}', perf] def get_units(*values): @@ -67,17 +95,18 @@ def generate_click(job_id: str, state: str, active_tab: str, *args): shared.state.begin('Generate') progress.add_task_to_queue(job_id) with call_queue.queue_lock: - yield [None, None, None, None, 'Control: starting'] + yield [None, None, None, None, 'Control: starting', ''] shared.mem_mon.reset() progress.start_task(job_id) try: + t = time.perf_counter() for results in control_run(state, units, helpers.input_source, helpers.input_init, helpers.input_mask, active_tab, True, *args): progress.record_results(job_id, results) - yield return_controls(results) + yield return_controls(results, t) except Exception as e: shared.log.error(f"Control exception: {e}") errors.display(e, 'Control') - yield [None, None, None, None, f'Control: Exception: {e}'] + yield [None, None, None, None, f'Control: Exception: {e}', ''] progress.finish_task(job_id) shared.state.end() @@ -106,11 +135,12 @@ def create_ui(_blocks: gr.Blocks=None): with gr.Accordion(open=False, label="Input", elem_id="control_input", elem_classes=["small-accordion"]): with gr.Row(): - show_preview = gr.Checkbox(label="Show preview", value=True, elem_id="control_show_preview") + show_input = gr.Checkbox(label="Show input", value=True, elem_id="control_show_input") + show_preview = gr.Checkbox(label="Show preview", value=False, elem_id="control_show_preview") with gr.Row(): - input_type = gr.Radio(label="Input type", choices=['Control only', 'Init image same as control', 'Separate init image'], value='Control only', type='index', elem_id='control_input_type') + input_type = gr.Radio(label="Control input type", choices=['Control only', 'Init image same as control', 'Separate init image'], value='Control only', type='index', elem_id='control_input_type') with gr.Row(): - denoising_strength = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label='Denoising strength', value=0.50, elem_id="control_input_denoising_strength") + denoising_strength = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label='Denoising strength', value=0.30, elem_id="control_input_denoising_strength") with gr.Accordion(open=False, label="Size", elem_id="control_size", 
elem_classes=["small-accordion"]): with gr.Tabs(): @@ -153,13 +183,13 @@ def create_ui(_blocks: gr.Blocks=None): override_settings = ui_common.create_override_inputs('control') with gr.Row(variant='compact', elem_id="control_extra_networks", elem_classes=["extra_networks_root"], visible=False) as extra_networks_ui: - from modules import timer, ui_extra_networks + from modules import ui_extra_networks extra_networks_ui = ui_extra_networks.create_ui(extra_networks_ui, btn_extra, 'control', skip_indexing=shared.opts.extra_network_skip_indexing) timer.startup.record('ui-networks') with gr.Row(elem_id='control-inputs'): - with gr.Column(scale=9, elem_id='control-input-column', visible=True) as _column_input: - gr.HTML('Control input
') + with gr.Column(scale=9, elem_id='control-input-column', visible=True) as column_input: + gr.HTML('Input
') with gr.Tabs(elem_classes=['control-tabs'], elem_id='control-tab-input'): with gr.Tab('Image', id='in-image') as tab_image: input_mode = gr.Label(value='select', visible=False) @@ -190,12 +220,12 @@ def create_ui(_blocks: gr.Blocks=None): gr.HTML('Output
') with gr.Tabs(elem_classes=['control-tabs'], elem_id='control-tab-output') as output_tabs: with gr.Tab('Gallery', id='out-gallery'): - output_gallery, _output_gen_info, _output_html_info, _output_html_info_formatted, _output_html_log = ui_common.create_output_panel("control", preview=True, prompt=prompt, height=gr_height) + output_gallery, _output_gen_info, _output_html_info, _output_html_info_formatted, output_html_log = ui_common.create_output_panel("control", preview=True, prompt=prompt, height=gr_height) with gr.Tab('Image', id='out-image'): output_image = gr.Image(label="Output", show_label=False, type="pil", interactive=False, tool="editor", height=gr_height, elem_id='control_output_image', elem_classes=['control-image']) with gr.Tab('Video', id='out-video'): output_video = gr.Video(label="Output", show_label=False, height=gr_height, elem_id='control_output_video', elem_classes=['control-image']) - with gr.Column(scale=9, elem_id='control-preview-column', visible=True) as column_preview: + with gr.Column(scale=9, elem_id='control-preview-column', visible=False) as column_preview: gr.HTML('Preview
') with gr.Tabs(elem_classes=['control-tabs'], elem_id='control-tab-preview'): with gr.Tab('Preview', id='preview-image') as _tab_preview: @@ -220,10 +250,11 @@ def create_ui(_blocks: gr.Blocks=None): process_id = gr.Dropdown(label="Processor", choices=processors.list_models(), value='None', elem_id=f'control_unit-{i}-process_name') model_id = gr.Dropdown(label="ControlNet", choices=controlnet.list_models(), value='None', elem_id=f'control_unit-{i}-model_name') ui_common.create_refresh_button(model_id, controlnet.list_models, lambda: {"choices": controlnet.list_models(refresh=True)}, f'refresh_controlnet_models_{i}') + control_mode = gr.Dropdown(label="CN Mode", choices=['default'], value='default', visible=False, elem_id=f'control_unit-{i}-mode') model_strength = gr.Slider(label="CN Strength", minimum=0.01, maximum=2.0, step=0.01, value=1.0, elem_id=f'control_unit-{i}-strength') - control_start = gr.Slider(label="Start", minimum=0.0, maximum=1.0, step=0.05, value=0, elem_id=f'control_unit-{i}-start') - control_end = gr.Slider(label="End", minimum=0.0, maximum=1.0, step=0.05, value=1.0, elem_id=f'control_unit-{i}-end') - control_mode = gr.Dropdown(label="CN Mode", choices=['', 'Canny', 'Tile', 'Depth', 'Blur', 'Pose', 'Gray', 'LQ'], value=0, type='index', visible=False, elem_id=f'control_unit-{i}-mode') + control_start = gr.Slider(label="CN Start", minimum=0.0, maximum=1.0, step=0.05, value=0, elem_id=f'control_unit-{i}-start') + control_end = gr.Slider(label="CN End", minimum=0.0, maximum=1.0, step=0.05, value=1.0, elem_id=f'control_unit-{i}-end') + control_tile = gr.Dropdown(label="CN Tiles", choices=[x.strip() for x in shared.opts.control_tiles.split(',') if 'x' in x], value='1x1', visible=False, elem_id=f'control_unit-{i}-tile') reset_btn = ui_components.ToolButton(value=ui_symbols.reset) image_upload = gr.UploadButton(label=ui_symbols.upload, file_types=['image'], elem_classes=['form', 'gradio-button', 'tool']) image_reuse= ui_components.ToolButton(value=ui_symbols.reuse) @@ -248,6 +279,7 @@ def create_ui(_blocks: gr.Blocks=None): control_start = control_start, control_end = control_end, control_mode = control_mode, + control_tile = control_tile, extra_controls = extra_controls, ) ) @@ -498,6 +530,7 @@ def create_ui(_blocks: gr.Blocks=None): btn_update = gr.Button('Update', interactive=True, visible=False, elem_id='control_update') btn_update.click(fn=get_units, inputs=controls, outputs=[], show_progress=True, queue=False) + show_input.change(fn=lambda x: gr.update(visible=x), inputs=[show_input], outputs=[column_input]) show_preview.change(fn=lambda x: gr.update(visible=x), inputs=[show_preview], outputs=[column_preview]) input_type.change(fn=lambda x: gr.update(visible=x == 2), inputs=[input_type], outputs=[column_init]) btn_prompt_counter.click(fn=call_queue.wrap_queued_call(ui_common.update_token_counter), inputs=[prompt, steps], outputs=[prompt_counter]) @@ -550,6 +583,7 @@ def create_ui(_blocks: gr.Blocks=None): output_video, output_gallery, result_txt, + output_html_log, ] control_dict = dict( fn=generate_click, diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index f6e6cee97..94664c5cb 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -16,7 +16,7 @@ import gradio as gr from PIL import Image from starlette.responses import FileResponse, JSONResponse -from modules import paths, shared, scripts, files_cache, errors, infotext +from modules import paths, shared, files_cache, errors, infotext from modules.ui_components import 
ToolButton import modules.ui_symbols as symbols @@ -135,6 +135,7 @@ def patch(self, text: str, tabname: str): return text.replace('~tabname', tabname) def create_xyz_grid(self): + """ xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module def add_prompt(p, opt, x): @@ -150,6 +151,7 @@ def add_prompt(p, opt, x): opt = xyz_grid.AxisOption(f"[Network] {self.title}", str, add_prompt, choices=lambda: [x["name"] for x in self.items]) if opt not in xyz_grid.axis_options: xyz_grid.axis_options.append(opt) + """ def link_preview(self, filename): quoted_filename = urllib.parse.quote(filename.replace('\\', '/')) @@ -458,17 +460,20 @@ def register_page(page: ExtraNetworksPage): def register_pages(): - from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion + debug('EN register-pages') from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints - from modules.ui_extra_networks_styles import ExtraNetworksPageStyles from modules.ui_extra_networks_vae import ExtraNetworksPageVAEs + from modules.ui_extra_networks_styles import ExtraNetworksPageStyles from modules.ui_extra_networks_history import ExtraNetworksPageHistory - debug('EN register-pages') + from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion register_page(ExtraNetworksPageCheckpoints()) - register_page(ExtraNetworksPageStyles()) - register_page(ExtraNetworksPageTextualInversion()) register_page(ExtraNetworksPageVAEs()) + register_page(ExtraNetworksPageStyles()) register_page(ExtraNetworksPageHistory()) + register_page(ExtraNetworksPageTextualInversion()) + if shared.native: + from modules.ui_extra_networks_lora import ExtraNetworksPageLora + register_page(ExtraNetworksPageLora()) if shared.opts.hypernetwork_enabled: from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks register_page(ExtraNetworksPageHypernetworks()) diff --git a/modules/ui_extra_networks_lora.py b/modules/ui_extra_networks_lora.py new file mode 100644 index 000000000..9dd1b3573 --- /dev/null +++ b/modules/ui_extra_networks_lora.py @@ -0,0 +1,123 @@ +import os +import json +import concurrent +import modules.lora.networks as networks +from modules import shared, ui_extra_networks + + +debug = os.environ.get('SD_LORA_DEBUG', None) is not None + + +class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Lora') + self.list_time = 0 + + def refresh(self): + networks.list_available_networks() + + @staticmethod + def get_tags(l, info): + tags = {} + try: + if l.metadata is not None: + modelspec_tags = l.metadata.get('modelspec.tags', {}) + possible_tags = l.metadata.get('ss_tag_frequency', {}) # tags from model metedata + if isinstance(possible_tags, str): + possible_tags = {} + if isinstance(modelspec_tags, str): + modelspec_tags = {} + if len(list(modelspec_tags)) > 0: + possible_tags.update(modelspec_tags) + for k, v in possible_tags.items(): + words = k.split('_', 1) if '_' in k else [v, k] + words = [str(w).replace('.json', '') for w in words] + if words[0] == '{}': + words[0] = 0 + tag = ' '.join(words[1:]).lower() + tags[tag] = words[0] + + def find_version(): + found_versions = [] + current_hash = l.hash[:8].upper() + all_versions = info.get('modelVersions', []) + for v in info.get('modelVersions', []): + for f in v.get('files', []): + if any(h.startswith(current_hash) for h in f.get('hashes', {}).values()): + found_versions.append(v) + if len(found_versions) == 0: + 
found_versions = all_versions + return found_versions + + for v in find_version(): # trigger words from info json + possible_tags = v.get('trainedWords', []) + if isinstance(possible_tags, list): + for tag_str in possible_tags: + for tag in tag_str.split(','): + tag = tag.strip().lower() + if tag not in tags: + tags[tag] = 0 + + possible_tags = info.get('tags', []) # tags from info json + if not isinstance(possible_tags, list): + possible_tags = list(possible_tags.values()) + for tag in possible_tags: + tag = tag.strip().lower() + if tag not in tags: + tags[tag] = 0 + except Exception: + pass + bad_chars = [';', ':', '<', ">", "*", '?', '\'', '\"', '(', ')', '[', ']', '{', '}', '\\', '/'] + clean_tags = {} + for k, v in tags.items(): + tag = ''.join(i for i in k if i not in bad_chars).strip() + clean_tags[tag] = v + + clean_tags.pop('img', None) + clean_tags.pop('dataset', None) + return clean_tags + + def create_item(self, name): + l = networks.available_networks.get(name) + if l is None: + shared.log.warning(f'Networks: type=lora registered={len(list(networks.available_networks))} file="{name}" not registered') + return None + try: + # path, _ext = os.path.splitext(l.filename) + name = os.path.splitext(os.path.relpath(l.filename, shared.cmd_opts.lora_dir))[0] + item = { + "type": 'Lora', + "name": name, + "filename": l.filename, + "hash": l.shorthash, + "prompt": json.dumps(f" "), + "metadata": json.dumps(l.metadata, indent=4) if l.metadata else None, + "mtime": os.path.getmtime(l.filename), + "size": os.path.getsize(l.filename), + "version": l.sd_version, + } + info = self.find_info(l.filename) + item["info"] = info + item["description"] = self.find_description(l.filename, info) # use existing info instead of double-read + item["tags"] = self.get_tags(l, info) + return item + except Exception as e: + shared.log.error(f'Networks: type=lora file="{name}" {e}') + if debug: + from modules import errors + errors.display(e, 'Lora') + return None + + def list_items(self): + items = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor: + future_items = {executor.submit(self.create_item, net): net for net in networks.available_networks} + for future in concurrent.futures.as_completed(future_items): + item = future.result() + if item is not None: + items.append(item) + self.update_all_previews(items) + return items + + def allowed_directories_for_previews(self): + return [shared.cmd_opts.lora_dir] diff --git a/modules/ui_img2img.py b/modules/ui_img2img.py index 22c89dac8..45c901c6d 100644 --- a/modules/ui_img2img.py +++ b/modules/ui_img2img.py @@ -1,7 +1,6 @@ import os from PIL import Image import gradio as gr -import numpy as np from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call from modules import timer, shared, ui_common, ui_sections, generation_parameters_copypaste, processing_vae @@ -56,7 +55,7 @@ def copy_image(img): def add_copy_image_controls(tab_name, elem): with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"): - for title, name in zip(['➠ Image', '➠ Sketch', '➠ Inpaint', '➠ Composite'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']): + for title, name in zip(['➠ Image', '➠ Inpaint', '➠ Sketch', '➠ Composite'], ['img2img', 'sketch', 'inpaint', 'composite']): if name == tab_name: gr.Button(title, elem_id=f'copy_to_{name}', interactive=False) copy_image_destinations[name] = elem @@ -67,45 +66,47 @@ def add_copy_image_controls(tab_name, elem): with gr.Tabs(elem_id="mode_img2img"): img2img_selected_tab = 
gr.State(0) # pylint: disable=abstract-class-instantiated state = gr.Textbox(value='', visible=False) - with gr.TabItem('Image', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=512) + with gr.TabItem('Image', id='img2img_image', elem_id="img2img_image_tab") as tab_img2img: + img_init = gr.Image(label="", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=512) interrogate_clip, interrogate_booru = ui_sections.create_interrogate_buttons('img2img') - add_copy_image_controls('img2img', init_img) - - with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: - sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=512) - add_copy_image_controls('sketch', sketch) - - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=512) - add_copy_image_controls('inpaint', init_img_with_mask) - - with gr.TabItem('Composite', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: - inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=512) - inpaint_color_sketch_orig = gr.State(None) # pylint: disable=abstract-class-instantiated - add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) - - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state - return state - - inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) + add_copy_image_controls('img2img', img_init) + + with gr.TabItem('Inpaint', id='img2img_inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: + img_inpaint = gr.Image(label="", elem_id="img2img_inpaint", show_label=False, source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=512) + add_copy_image_controls('inpaint', img_inpaint) + + with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_sketch_tab") as tab_sketch: + img_sketch = gr.Image(label="", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=512) + add_copy_image_controls('sketch', img_sketch) + + with gr.TabItem('Composite', id='img2img_composite', elem_id="img2img_composite_tab") as tab_inpaint_color: + img_composite = gr.Image(label="", show_label=False, elem_id="img2img_composite", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=512) + img_composite_orig = gr.State(None) # pylint: disable=abstract-class-instantiated + img_composite_orig_update = False + + def fn_img_composite_upload(): + nonlocal 
img_composite_orig_update + img_composite_orig_update = True + def fn_img_composite_change(img, img_composite): + nonlocal img_composite_orig_update + res = img if img_composite_orig_update else img_composite + img_composite_orig_update = False + return res + + img_composite.upload(fn=fn_img_composite_upload, inputs=[], outputs=[]) + img_composite.change(fn=fn_img_composite_change, inputs=[img_composite, img_composite_orig], outputs=[img_composite_orig]) + add_copy_image_controls('composite', img_composite) with gr.TabItem('Upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask") with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: - hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML(f"
Upload images or process images in a directory
Add inpaint batch mask directory to enable inpaint batch processing {hidden}
") + gr.HTML("
Run image processing on upload images or files in a folder
If masks are provided will run inpaint
") img2img_batch_files = gr.Files(label="Batch Process", interactive=True, elem_id="img2img_image_batch") - img2img_batch_input_dir = gr.Textbox(label="Inpaint batch input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") - img2img_batch_output_dir = gr.Textbox(label="Inpaint batch output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") - img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") + img2img_batch_input_dir = gr.Textbox(label="Batch input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Batch output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + img2img_batch_inpaint_mask_dir = gr.Textbox(label="Batch mask directory", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] for i, tab in enumerate(img2img_tabs): @@ -120,13 +121,13 @@ def update_orig(image, state): with gr.Accordion(open=False, label="Sampler", elem_classes=["small-accordion"], elem_id="img2img_sampler_group"): steps, sampler_index = ui_sections.create_sampler_and_steps_selection(None, "img2img") ui_sections.create_sampler_options('img2img') - resize_mode, resize_name, resize_context, width, height, scale_by, selected_scale_tab = ui_sections.create_resize_inputs('img2img', [init_img, sketch], latent=True, non_zero=False) + resize_mode, resize_name, resize_context, width, height, scale_by, selected_scale_tab = ui_sections.create_resize_inputs('img2img', [img_init, img_sketch], latent=True, non_zero=False) batch_count, batch_size = ui_sections.create_batch_inputs('img2img', accordion=True) seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w = ui_sections.create_seed_inputs('img2img') with gr.Accordion(open=False, label="Denoise", elem_classes=["small-accordion"], elem_id="img2img_denoise_group"): with gr.Row(): - denoising_strength = gr.Slider(minimum=0.0, maximum=0.99, step=0.01, label='Denoising strength', value=0.50, elem_id="img2img_denoising_strength") + denoising_strength = gr.Slider(minimum=0.0, maximum=0.99, step=0.01, label='Denoising strength', value=0.30, elem_id="img2img_denoising_strength") refiner_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Denoise start', value=0.0, elem_id="img2img_refiner_start") full_quality, tiling, hidiffusion, cfg_scale, clip_skip, image_cfg_scale, diffusers_guidance_rescale, pag_scale, pag_adaptive, cfg_end = ui_sections.create_advanced_inputs('img2img') @@ -167,13 +168,8 @@ def select_img2img_tab(tab): img2img_args = [ dummy_component1, state, dummy_component2, img2img_prompt, img2img_negative_prompt, img2img_prompt_styles, - init_img, - sketch, - init_img_with_mask, - inpaint_color_sketch, - inpaint_color_sketch_orig, - init_img_inpaint, - init_mask_inpaint, + img_init, img_sketch, img_inpaint, img_composite, img_composite_orig, + init_img_inpaint, init_mask_inpaint, steps, sampler_index, mask_blur, mask_alpha, @@ -225,10 +221,7 @@ def select_img2img_tab(tab): img2img_batch_files, img2img_batch_input_dir, img2img_batch_output_dir, - init_img, - sketch, - init_img_with_mask, - inpaint_color_sketch, + img_init, img_sketch, img_inpaint, img_composite, init_img_inpaint, ], outputs=[img2img_prompt, dummy_component], @@ -285,7 +278,8 @@ def select_img2img_tab(tab): (seed_resize_from_h, "Seed resize 
from-2"), *modules.scripts.scripts_img2img.infotext_fields ] - generation_parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) - generation_parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) + generation_parameters_copypaste.add_paste_fields("img2img", img_init, img2img_paste_fields, override_settings) + generation_parameters_copypaste.add_paste_fields("sketch", img_sketch, img2img_paste_fields, override_settings) + generation_parameters_copypaste.add_paste_fields("inpaint", img_inpaint, img2img_paste_fields, override_settings) img2img_bindings = generation_parameters_copypaste.ParamBinding(paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None) generation_parameters_copypaste.register_paste_params_button(img2img_bindings) diff --git a/modules/ui_models.py b/modules/ui_models.py index 624c3849d..7ab8b0d07 100644 --- a/modules/ui_models.py +++ b/modules/ui_models.py @@ -8,7 +8,7 @@ from modules.ui_components import ToolButton from modules.ui_common import create_refresh_button from modules.call_queue import wrap_gradio_gpu_call -from modules.shared import opts, log, req, readfile, max_workers +from modules.shared import opts, log, req, readfile, max_workers, native import modules.ui_symbols import modules.errors import modules.hashes @@ -794,6 +794,10 @@ def civit_update_download(): civit_results4.select(fn=civit_update_select, inputs=[civit_results4], outputs=[models_outcome, civit_update_download_btn]) civit_update_download_btn.click(fn=civit_update_download, inputs=[], outputs=[models_outcome]) + if native: + from modules.lora.lora_extract import create_ui as lora_extract_ui + lora_extract_ui() + for ui in extra_ui: if callable(ui): ui() diff --git a/modules/ui_sections.py b/modules/ui_sections.py index f15edb4bd..fcf53cf70 100644 --- a/modules/ui_sections.py +++ b/modules/ui_sections.py @@ -276,11 +276,11 @@ def set_sampler_preset(preset): else: # shared.native with gr.Row(elem_classes=['flex-break']): - sampler_sigma = gr.Dropdown(label='Sigma method', elem_id=f"{tabname}_sampler_sigma", choices=['default', 'karras', 'betas', 'exponential', 'lambdas'], value=shared.opts.schedulers_sigma, type='value') + sampler_sigma = gr.Dropdown(label='Sigma method', elem_id=f"{tabname}_sampler_sigma", choices=['default', 'karras', 'betas', 'exponential', 'lambdas', 'flowmatch'], value=shared.opts.schedulers_sigma, type='value') sampler_spacing = gr.Dropdown(label='Timestep spacing', elem_id=f"{tabname}_sampler_spacing", choices=['default', 'linspace', 'leading', 'trailing'], value=shared.opts.schedulers_timestep_spacing, type='value') with gr.Row(elem_classes=['flex-break']): sampler_beta = gr.Dropdown(label='Beta schedule', elem_id=f"{tabname}_sampler_beta", choices=['default', 'linear', 'scaled', 'cosine'], value=shared.opts.schedulers_beta_schedule, type='value') - sampler_prediction = gr.Dropdown(label='Prediction method', elem_id=f"{tabname}_sampler_prediction", choices=['default', 'epsilon', 'sample', 'v_prediction'], value=shared.opts.schedulers_prediction_type, type='value') + sampler_prediction = gr.Dropdown(label='Prediction method', elem_id=f"{tabname}_sampler_prediction", choices=['default', 'epsilon', 'sample', 'v_prediction', 'flow_prediction'], value=shared.opts.schedulers_prediction_type, type='value') with gr.Row(elem_classes=['flex-break']): sampler_presets = gr.Dropdown(label='Timesteps presets', 
elem_id=f"{tabname}_sampler_presets", choices=['None', 'AYS SD15', 'AYS SDXL'], value='None', type='value') sampler_timesteps = gr.Textbox(label='Timesteps override', elem_id=f"{tabname}_sampler_timesteps", value=shared.opts.schedulers_timesteps) diff --git a/modules/xadapter/adapter.py b/modules/xadapter/adapter.py index 69030fae3..4096f71e7 100644 --- a/modules/xadapter/adapter.py +++ b/modules/xadapter/adapter.py @@ -266,8 +266,6 @@ def forward(self, x, t=None): b, c, _, _ = x[-1].shape if t is not None: if not torch.is_tensor(t): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) is_mps = x[0].device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 diff --git a/modules/xadapter/unet_adapter.py b/modules/xadapter/unet_adapter.py index 5022f1847..5890c7749 100644 --- a/modules/xadapter/unet_adapter.py +++ b/modules/xadapter/unet_adapter.py @@ -807,8 +807,6 @@ def forward( # 1. time timesteps = timestep if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 @@ -1012,7 +1010,7 @@ def forward( if is_bridge: if up_block_additional_residual[0].shape != sample.shape: - pass # TODO VM patch + pass elif fusion_guidance_scale is not None: sample = sample + fusion_guidance_scale * (up_block_additional_residual.pop(0) - sample) else: @@ -1051,7 +1049,7 @@ def forward( ################# bridge usage ################# if is_bridge and len(up_block_additional_residual) > 0: if sample.shape != up_block_additional_residual[0].shape: - pass # TODO VM PATCH + pass elif fusion_guidance_scale is not None: sample = sample + fusion_guidance_scale * (up_block_additional_residual.pop(0) - sample) else: diff --git a/modules/zluda_installer.py b/modules/zluda_installer.py index 84c130e8d..5e43a6635 100644 --- a/modules/zluda_installer.py +++ b/modules/zluda_installer.py @@ -31,7 +31,10 @@ def install(zluda_path: os.PathLike) -> None: if os.path.exists(zluda_path): return - urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/rel.{os.environ.get("ZLUDA_HASH", "c0804ca624963aab420cb418412b1c7fbae3454b")}/ZLUDA-windows-rocm{rocm.version[0]}-amd64.zip', '_zluda') + commit = os.environ.get("ZLUDA_HASH", "1b6e012d8f2404840b524e2abae12cb91e1ac01d") + if rocm.version == "6.1": + commit = "c0804ca624963aab420cb418412b1c7fbae3454b" + urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/rel.{commit}/ZLUDA-windows-rocm{rocm.version[0]}-amd64.zip', '_zluda') with zipfile.ZipFile('_zluda', 'r') as archive: infos = archive.infolist() for info in infos: diff --git a/requirements.txt b/requirements.txt index 12a9f85cb..a67ad1ba2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,25 +34,25 @@ pi-heif # versioned safetensors==0.4.5 tensordict==0.1.2 -peft==0.13.1 +peft==0.14.0 httpx==0.24.1 compel==2.0.3 torchsde==0.2.6 antlr4-python3-runtime==4.9.3 requests==2.32.3 tqdm==4.66.5 -accelerate==1.1.1 +accelerate==1.2.1 opencv-contrib-python-headless==4.9.0.80 einops==0.4.1 gradio==3.43.2 -huggingface_hub==0.26.2 +huggingface_hub==0.27.0 numexpr==2.8.8 numpy==1.26.4 numba==0.59.1 protobuf==4.25.3 pytorch_lightning==1.9.4 -tokenizers==0.20.3 
-transformers==4.46.2 +tokenizers==0.21.0 +transformers==4.47.1 urllib3==1.26.19 Pillow==10.4.0 timm==0.9.16 diff --git a/scripts/animatediff.py b/scripts/animatediff.py index 4c50f9cf6..f44c85bb7 100644 --- a/scripts/animatediff.py +++ b/scripts/animatediff.py @@ -182,14 +182,14 @@ def set_free_init(method, iters, order, spatial, temporal): def set_free_noise(frames): context_length = 16 context_stride = 4 - if frames >= context_length: + if frames >= context_length and hasattr(shared.sd_model, 'enable_free_noise'): shared.log.debug(f'AnimateDiff free noise: frames={frames} context={context_length} stride={context_stride}') shared.sd_model.enable_free_noise(context_length=context_length, context_stride=context_stride) class Script(scripts.Script): def title(self): - return 'Video AnimateDiff' + return 'Video: AnimateDiff' def show(self, is_img2img): # return scripts.AlwaysVisible if shared.native else False @@ -250,7 +250,7 @@ def run(self, p: processing.StableDiffusionProcessing, adapter_index, frames, lo processing.fix_seed(p) p.extra_generation_params['AnimateDiff'] = loaded_adapter p.do_not_save_grid = True - p.ops.append('animatediff') + p.ops.append('video') p.task_args['generator'] = None p.task_args['num_frames'] = frames p.task_args['num_inference_steps'] = p.steps diff --git a/scripts/cogvideo.py b/scripts/cogvideo.py index 7f2c7225e..e689a5e3f 100644 --- a/scripts/cogvideo.py +++ b/scripts/cogvideo.py @@ -22,7 +22,7 @@ class Script(scripts.Script): def title(self): - return 'Video CogVideoX' + return 'Video: CogVideoX' def show(self, is_img2img): return shared.native @@ -51,7 +51,7 @@ def video_type_change(video_type): with gr.Row(): video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') duration = gr.Slider(label='Duration', minimum=0.25, maximum=30, step=0.25, value=8, visible=False) - with gr.Accordion('Optional init video', open=False): + with gr.Accordion('Optional init image or video', open=False): with gr.Row(): image = gr.Image(value=None, label='Image', type='pil', source='upload', width=256, height=256) video = gr.Video(value=None, label='Video', source='upload', width=256, height=256) @@ -169,25 +169,18 @@ def generate(self, p: processing.StableDiffusionProcessing, model: str): callback_on_step_end=diffusers_callback, callback_on_step_end_tensor_inputs=['latents'], ) - if getattr(p, 'image', False): - if 'I2V' not in model: - shared.log.error(f'CogVideoX: model={model} image input not supported') - return [] - args['image'] = self.image(p, p.image) - args['num_frames'] = p.frames # only txt2vid has num_frames - shared.sd_model = sd_models.switch_pipe(diffusers.CogVideoXImageToVideoPipeline, shared.sd_model) - elif getattr(p, 'video', False): - if 'I2V' in model: - shared.log.error(f'CogVideoX: model={model} image input not supported') - return [] - args['video'] = self.video(p, p.video) - shared.sd_model = sd_models.switch_pipe(diffusers.CogVideoXVideoToVideoPipeline, shared.sd_model) + if 'I2V' in model: + if hasattr(p, 'video') and p.video is not None: + args['video'] = self.video(p, p.video) + shared.sd_model = sd_models.switch_pipe(diffusers.CogVideoXVideoToVideoPipeline, shared.sd_model) + elif (hasattr(p, 'image') and p.image is not None) or (hasattr(p, 'init_images') and len(p.init_images) > 0): + p.init_images = [p.image] if hasattr(p, 'image') and p.image is not None else p.init_images + args['image'] = self.image(p, p.init_images[0]) + shared.sd_model = sd_models.switch_pipe(diffusers.CogVideoXImageToVideoPipeline, 
shared.sd_model) else: - if 'I2V' in model: - shared.log.error(f'CogVideoX: model={model} image input not supported') - return [] - args['num_frames'] = p.frames # only txt2vid has num_frames shared.sd_model = sd_models.switch_pipe(diffusers.CogVideoXPipeline, shared.sd_model) + args['num_frames'] = p.frames # only txt2vid has num_frames + shared.log.info(f"CogVideoX: class={shared.sd_model.__class__.__name__} frames={p.frames} input={args.get('video', None) or args.get('image', None)}") if debug: shared.log.debug(f'CogVideoX args: {args}') frames = shared.sd_model(**args).frames[0] @@ -199,7 +192,7 @@ def generate(self, p: processing.StableDiffusionProcessing, model: str): errors.display(e, 'CogVideoX') t1 = time.time() its = (len(frames) * p.steps) / (t1 - t0) - shared.log.info(f'CogVideoX: frames={len(frames)} its={its:.2f} time={t1 - t0:.2f}') + shared.log.info(f'CogVideoX: frame={frames[0] if len(frames) > 0 else None} frames={len(frames)} its={its:.2f} time={t1 - t0:.2f}') return frames # auto-executed by the script-callback @@ -209,7 +202,7 @@ def run(self, p: processing.StableDiffusionProcessing, model, sampler, frames, g p.extra_generation_params['CogVideoX'] = model p.do_not_save_grid = True if 'animatediff' not in p.ops: - p.ops.append('cogvideox') + p.ops.append('video') if override: p.width = 720 p.height = 480 diff --git a/scripts/flux_tools.py b/scripts/flux_tools.py new file mode 100644 index 000000000..50904eedb --- /dev/null +++ b/scripts/flux_tools.py @@ -0,0 +1,149 @@ +# https://github.com/huggingface/diffusers/pull/9985 + +import time +import gradio as gr +import diffusers +from modules import scripts, processing, shared, devices, sd_models +from installer import install + + +# redux_pipe: diffusers.FluxPriorReduxPipeline = None +redux_pipe = None +processor_canny = None +processor_depth = None +title = 'Flux Tools' + + +class Script(scripts.Script): + def title(self): + return f'{title}' + + def show(self, is_img2img): + return is_img2img if shared.native else False + + def ui(self, _is_img2img): # ui elements + with gr.Row(): + gr.HTML('  Flux.1 Redux
') + with gr.Row(): + tool = gr.Dropdown(label='Tool', choices=['None', 'Redux', 'Fill', 'Canny', 'Depth'], value='None') + with gr.Row(): + prompt = gr.Slider(label='Redux prompt strength', minimum=0, maximum=2, step=0.01, value=0, visible=False) + process = gr.Checkbox(label='Control preprocess input images', value=True, visible=False) + strength = gr.Checkbox(label='Control override denoise strength', value=True, visible=False) + + def display(tool): + return [ + gr.update(visible=tool in ['Redux']), + gr.update(visible=tool in ['Canny', 'Depth']), + gr.update(visible=tool in ['Canny', 'Depth']), + ] + + tool.change(fn=display, inputs=[tool], outputs=[prompt, process, strength]) + return [tool, prompt, strength, process] + + def run(self, p: processing.StableDiffusionProcessing, tool: str = 'None', prompt: float = 1.0, strength: bool = True, process: bool = True): # pylint: disable=arguments-differ + global redux_pipe, processor_canny, processor_depth # pylint: disable=global-statement + if tool is None or tool == 'None': + return + supported_model_list = ['f1'] + if shared.sd_model_type not in supported_model_list: + shared.log.warning(f'{title}: class={shared.sd_model.__class__.__name__} model={shared.sd_model_type} required={supported_model_list}') + return None + image = getattr(p, 'init_images', None) + if image is None or len(image) == 0: + shared.log.error(f'{title}: tool={tool} no init_images') + return None + else: + image = image[0] if isinstance(image, list) else image + + shared.log.info(f'{title}: tool={tool} init') + + t0 = time.time() + if tool == 'Redux': + # pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained("black-forest-labs/FLUX.1-Redux-dev", revision="refs/pr/8", torch_dtype=torch.bfloat16).to("cuda") + shared.log.debug(f'{title}: tool={tool} prompt={prompt}') + if redux_pipe is None: + redux_pipe = diffusers.FluxPriorReduxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-Redux-dev", + revision="refs/pr/8", + torch_dtype=devices.dtype, + cache_dir=shared.opts.hfcache_dir + ).to(devices.device) + if prompt > 0: + shared.log.info(f'{title}: tool={tool} load text encoder') + redux_pipe.tokenizer, redux_pipe.tokenizer_2 = shared.sd_model.tokenizer, shared.sd_model.tokenizer_2 + redux_pipe.text_encoder, redux_pipe.text_encoder_2 = shared.sd_model.text_encoder, shared.sd_model.text_encoder_2 + sd_models.apply_balanced_offload(redux_pipe) + redux_output = redux_pipe( + image=image, + prompt=p.prompt if prompt > 0 else None, + prompt_embeds_scale=[prompt], + pooled_prompt_embeds_scale=[prompt], + ) + if prompt > 0: + redux_pipe.tokenizer, redux_pipe.tokenizer_2 = None, None + redux_pipe.text_encoder, redux_pipe.text_encoder_2 = None, None + devices.torch_gc() + for k, v in redux_output.items(): + p.task_args[k] = v + else: + if redux_pipe is not None: + shared.log.debug(f'{title}: tool=Redux unload') + redux_pipe = None + + if tool == 'Fill': + # pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, revision="refs/pr/4").to("cuda") + if p.image_mask is None: + shared.log.error(f'{title}: tool={tool} no image_mask') + return None + if shared.sd_model.__class__.__name__ != 'FluxFillPipeline': + shared.opts.data["sd_model_checkpoint"] = "black-forest-labs/FLUX.1-Fill-dev" + sd_models.reload_model_weights(op='model', revision="refs/pr/4") + p.task_args['image'] = image + p.task_args['mask_image'] = p.image_mask + + if tool == 'Canny': + # pipe = 
diffusers.FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16, revision="refs/pr/1").to("cuda") + install('controlnet-aux') + install('timm==0.9.16') + if shared.sd_model.__class__.__name__ != 'FluxControlPipeline' or 'Canny' not in shared.opts.sd_model_checkpoint: + shared.opts.data["sd_model_checkpoint"] = "black-forest-labs/FLUX.1-Canny-dev" + sd_models.reload_model_weights(op='model', revision="refs/pr/1") + if processor_canny is None: + from controlnet_aux import CannyDetector + processor_canny = CannyDetector() + if process: + control_image = processor_canny(image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024) + p.task_args['control_image'] = control_image + else: + p.task_args['control_image'] = image + if strength: + p.task_args['strength'] = None + else: + if processor_canny is not None: + shared.log.debug(f'{title}: tool=Canny unload processor') + processor_canny = None + + if tool == 'Depth': + # pipe = diffusers.FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16, revision="refs/pr/1").to("cuda") + install('git+https://github.com/huggingface/image_gen_aux.git', 'image_gen_aux') + if shared.sd_model.__class__.__name__ != 'FluxControlPipeline' or 'Depth' not in shared.opts.sd_model_checkpoint: + shared.opts.data["sd_model_checkpoint"] = "black-forest-labs/FLUX.1-Depth-dev" + sd_models.reload_model_weights(op='model', revision="refs/pr/1") + if processor_depth is None: + from image_gen_aux import DepthPreprocessor + processor_depth = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf") + if process: + control_image = processor_depth(image)[0].convert("RGB") + p.task_args['control_image'] = control_image + else: + p.task_args['control_image'] = image + if strength: + p.task_args['strength'] = None + else: + if processor_depth is not None: + shared.log.debug(f'{title}: tool=Depth unload processor') + processor_depth = None + + shared.log.debug(f'{title}: tool={tool} ready time={time.time() - t0:.2f}') + devices.torch_gc() diff --git a/scripts/freescale.py b/scripts/freescale.py new file mode 100644 index 000000000..672ceea41 --- /dev/null +++ b/scripts/freescale.py @@ -0,0 +1,130 @@ +import gradio as gr +from modules import scripts, processing, shared, sd_models + + +registered = False + + +class Script(scripts.Script): + def __init__(self): + super().__init__() + self.orig_pipe = None + self.orig_slice = None + self.orig_tile = None + self.is_img2img = False + + def title(self): + return 'FreeScale: Tuning-Free Scale Fusion' + + def show(self, is_img2img): + self.is_img2img = is_img2img + return shared.native + + def ui(self, _is_img2img): # ui elements + with gr.Row(): + gr.HTML('  FreeScale: Tuning-Free Scale Fusion
') + with gr.Row(): + cosine_scale = gr.Slider(minimum=0.1, maximum=5.0, value=2.0, label='Cosine scale') + override_sampler = gr.Checkbox(value=True, label='Override sampler') + with gr.Row(visible=self.is_img2img): + cosine_scale_bg = gr.Slider(minimum=0.1, maximum=5.0, value=1.0, label='Cosine Background') + dilate_tau = gr.Slider(minimum=1, maximum=100, value=35, label='Dilate tau') + with gr.Row(): + s1_enable = gr.Checkbox(value=True, label='1st Stage', interactive=False) + s1_scale = gr.Slider(minimum=1, maximum=8.0, value=1.0, label='Scale') + s1_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='Restart step') + with gr.Row(): + s2_enable = gr.Checkbox(value=True, label='2nd Stage') + s2_scale = gr.Slider(minimum=1, maximum=8.0, value=2.0, label='Scale') + s2_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='Restart step') + with gr.Row(): + s3_enable = gr.Checkbox(value=False, label='3rd Stage') + s3_scale = gr.Slider(minimum=1, maximum=8.0, value=3.0, label='Scale') + s3_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='Restart step') + with gr.Row(): + s4_enable = gr.Checkbox(value=False, label='4th Stage') + s4_scale = gr.Slider(minimum=1, maximum=8.0, value=4.0, label='Scale') + s4_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='Restart step') + return [cosine_scale, override_sampler, cosine_scale_bg, dilate_tau, s1_enable, s1_scale, s1_restart, s2_enable, s2_scale, s2_restart, s3_enable, s3_scale, s3_restart, s4_enable, s4_scale, s4_restart] + + def run(self, p: processing.StableDiffusionProcessing, cosine_scale, override_sampler, cosine_scale_bg, dilate_tau, s1_enable, s1_scale, s1_restart, s2_enable, s2_scale, s2_restart, s3_enable, s3_scale, s3_restart, s4_enable, s4_scale, s4_restart): # pylint: disable=arguments-differ + supported_model_list = ['sdxl'] + if shared.sd_model_type not in supported_model_list: + shared.log.warning(f'FreeScale: class={shared.sd_model.__class__.__name__} model={shared.sd_model_type} required={supported_model_list}') + return None + + if self.is_img2img: + if p.init_images is None or len(p.init_images) == 0: + shared.log.warning('FreeScale: missing input image') + return None + + from modules.freescale import StableDiffusionXLFreeScale, StableDiffusionXLFreeScaleImg2Img + self.orig_pipe = shared.sd_model + self.orig_slice = shared.opts.diffusers_vae_slicing + self.orig_tile = shared.opts.diffusers_vae_tiling + + def scale(x): + if (p.width == 0 or p.height == 0) and p.init_images is not None: + p.width, p.height = p.init_images[0].width, p.init_images[0].height + resolution = [int(8 * p.width * x // 8), int(8 * p.height * x // 8)] + return resolution + + scales = [] + resolutions_list = [] + restart_steps = [] + if s1_enable: + scales.append(s1_scale) + resolutions_list.append(scale(s1_scale)) + restart_steps.append(int(p.steps * s1_restart)) + if s2_enable and s2_scale > s1_scale: + scales.append(s2_scale) + resolutions_list.append(scale(s2_scale)) + restart_steps.append(int(p.steps * s2_restart)) + if s3_enable and s3_scale > s2_scale: + scales.append(s3_scale) + resolutions_list.append(scale(s3_scale)) + restart_steps.append(int(p.steps * s3_restart)) + if s4_enable and s4_scale > s3_scale: + scales.append(s4_scale) + resolutions_list.append(scale(s4_scale)) + restart_steps.append(int(p.steps * s4_restart)) + + p.task_args['resolutions_list'] = resolutions_list + p.task_args['cosine_scale'] = cosine_scale + p.task_args['restart_steps'] = [min(max(1, step), p.steps-1) for step in 
restart_steps] + if self.is_img2img: + p.task_args['cosine_scale_bg'] = cosine_scale_bg + p.task_args['dilate_tau'] = dilate_tau + p.task_args['img_path'] = p.init_images[0] + p.init_images = None + if override_sampler: + p.sampler_name = 'Euler a' + + if p.width < 1024 or p.height < 1024: + shared.log.error(f'FreeScale: width={p.width} height={p.height} minimum=1024') + return None + + if not self.is_img2img: + shared.sd_model = sd_models.switch_pipe(StableDiffusionXLFreeScale, shared.sd_model) + else: + shared.sd_model = sd_models.switch_pipe(StableDiffusionXLFreeScaleImg2Img, shared.sd_model) + shared.sd_model.enable_vae_slicing() + shared.sd_model.enable_vae_tiling() + + shared.log.info(f'FreeScale: mode={"txt" if not self.is_img2img else "img"} cosine={cosine_scale} bg={cosine_scale_bg} tau={dilate_tau} scales={scales} resolutions={resolutions_list} steps={restart_steps} sampler={p.sampler_name}') + resolutions = ','.join([f'{x[0]}x{x[1]}' for x in resolutions_list]) + steps = ','.join([str(x) for x in restart_steps]) + p.extra_generation_params["FreeScale"] = f'cosine {cosine_scale} resolutions {resolutions} steps {steps}' + + def after(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, *args): # pylint: disable=arguments-differ, unused-argument + if self.orig_pipe is None: + return processed + # restore pipeline + if shared.sd_model_type == "sdxl": + shared.sd_model = self.orig_pipe + self.orig_pipe = None + if not self.orig_slice: + shared.sd_model.disable_vae_slicing() + if not self.orig_tile: + shared.sd_model.disable_vae_tiling() + return processed diff --git a/scripts/hunyuanvideo.py b/scripts/hunyuanvideo.py new file mode 100644 index 000000000..b94c8b8f8 --- /dev/null +++ b/scripts/hunyuanvideo.py @@ -0,0 +1,111 @@ +import time +import torch +import gradio as gr +import diffusers +from modules import scripts, processing, shared, images, devices, sd_models, sd_checkpoint, model_quant + + +repo_id = 'tencent/HunyuanVideo' +""" +prompt_template = { # default + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the contents, including objects, people, and anything else." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the contents." + "4. Background environment, light, style, atmosphere, and qualities." + "5. Camera angles, movements, and transitions used in the video." + "6. Thematic and aesthetic concepts associated with the scene, i.e. realistic, futuristic, fairy tale, etc<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} +""" + + +class Script(scripts.Script): + def title(self): + return 'Video: Hunyuan Video' + + def show(self, is_img2img): + return not is_img2img if shared.native else False + + # return signature is array of gradio components + def ui(self, _is_img2img): + def video_type_change(video_type): + return [ + gr.update(visible=video_type != 'None'), + gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), + gr.update(visible=video_type == 'MP4'), + gr.update(visible=video_type == 'MP4'), + ] + + with gr.Row(): + gr.HTML('  Hunyuan Video
') + with gr.Row(): + num_frames = gr.Slider(label='Frames', minimum=9, maximum=257, step=1, value=45) + with gr.Row(): + video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') + duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False) + with gr.Row(): + gif_loop = gr.Checkbox(label='Loop', value=True, visible=False) + mp4_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) + mp4_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) + video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, gif_loop, mp4_pad, mp4_interpolate]) + return [num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate] + + def run(self, p: processing.StableDiffusionProcessing, num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate): # pylint: disable=arguments-differ, unused-argument + # set params + num_frames = int(num_frames) + p.width = 32 * int(p.width // 32) + p.height = 32 * int(p.height // 32) + p.task_args['output_type'] = 'pil' + p.task_args['generator'] = torch.manual_seed(p.seed) + p.task_args['num_frames'] = num_frames + # p.task_args['prompt_template'] = prompt_template + p.sampler_name = 'Default' + p.do_not_save_grid = True + p.ops.append('video') + + # load model + cls = diffusers.HunyuanVideoPipeline + if shared.sd_model.__class__ != cls: + sd_models.unload_model_weights() + kwargs = {} + kwargs = model_quant.create_bnb_config(kwargs) + kwargs = model_quant.create_ao_config(kwargs) + transformer = diffusers.HunyuanVideoTransformer3DModel.from_pretrained( + repo_id, + subfolder="transformer", + torch_dtype=devices.dtype, + revision="refs/pr/18", + cache_dir = shared.opts.hfcache_dir, + **kwargs + ) + shared.sd_model = cls.from_pretrained( + repo_id, + transformer=transformer, + revision="refs/pr/18", + cache_dir = shared.opts.hfcache_dir, + torch_dtype=devices.dtype, + **kwargs + ) + shared.sd_model.scheduler._shift = 7.0 # pylint: disable=protected-access + sd_models.set_diffuser_options(shared.sd_model) + shared.sd_model.sd_checkpoint_info = sd_checkpoint.CheckpointInfo(repo_id) + shared.sd_model.sd_model_hash = None + shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) + shared.sd_model.vae.enable_slicing() + shared.sd_model.vae.enable_tiling() + devices.torch_gc(force=True) + shared.log.debug(f'Video: cls={shared.sd_model.__class__.__name__} args={p.task_args}') + + # run processing + t0 = time.time() + processed = processing.process_images(p) + t1 = time.time() + if processed is not None and len(processed.images) > 0: + shared.log.info(f'Video: frames={len(processed.images)} time={t1-t0:.2f}') + if video_type != 'None': + images.save_video(p, filename=None, images=processed.images, video_type=video_type, duration=duration, loop=gif_loop, pad=mp4_pad, interpolate=mp4_interpolate) + return processed diff --git a/scripts/image2video.py b/scripts/image2video.py index 876ed3193..ad6615f67 100644 --- a/scripts/image2video.py +++ b/scripts/image2video.py @@ -13,7 +13,7 @@ class Script(scripts.Script): def title(self): - return 'Video VGen Image-to-Video' + return 'Video: VGen Image-to-Video' def show(self, is_img2img): return is_img2img if shared.native else False @@ -73,7 +73,7 @@ def run(self, p: processing.StableDiffusionProcessing, model_name, num_frames, v model = [m for m in MODELS if m['name'] == model_name][0] repo_id = model['url'] 
shared.log.debug(f'Image2Video: model={model_name} frames={num_frames}, video={video_type} duration={duration} loop={gif_loop} pad={mp4_pad} interpolate={mp4_interpolate}') - p.ops.append('image2video') + p.ops.append('video') p.do_not_save_grid = True orig_pipeline = shared.sd_model diff --git a/scripts/instantir.py b/scripts/instantir.py index 5eb7d503a..6ab7733fe 100644 --- a/scripts/instantir.py +++ b/scripts/instantir.py @@ -80,7 +80,7 @@ def run(self, p: processing.StableDiffusionProcessing, *args): # pylint: disable devices.torch_gc() def after(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, *args): # pylint: disable=arguments-differ, unused-argument - # TODO instantir is a mess to unload + # TODO instantir: a mess to unload """ if self.orig_pipe is None: return processed diff --git a/scripts/ipadapter.py b/scripts/ipadapter.py index 60c70b9dc..5ca4ca578 100644 --- a/scripts/ipadapter.py +++ b/scripts/ipadapter.py @@ -57,6 +57,7 @@ def ui(self, _is_img2img): mask_galleries = [] with gr.Row(): num_adapters = gr.Slider(label="Active IP adapters", minimum=1, maximum=MAX_ADAPTERS, step=1, value=1, scale=1) + unload_adapter = gr.Checkbox(label='Unload adapter', value=False, interactive=True) for i in range(MAX_ADAPTERS): with gr.Accordion(f'Adapter {i+1}', visible=i==0) as unit: with gr.Row(): @@ -85,7 +86,7 @@ def ui(self, _is_img2img): layers_label = gr.HTML('InstantStyle: advanced layer activation', visible=False) layers = gr.Text(label='Layer scales', placeholder='{\n"down": {"block_2": [0.0, 1.0]},\n"up": {"block_0": [0.0, 1.0, 0.0]}\n}', rows=1, type='text', interactive=True, lines=5, visible=False, show_label=False) layers_active.change(fn=self.display_advanced, inputs=[layers_active], outputs=[layers_label, layers]) - return [num_adapters] + adapters + scales + files + crops + starts + ends + masks + [layers_active] + [layers] + return [num_adapters] + [unload_adapter] + adapters + scales + files + crops + starts + ends + masks + [layers_active] + [layers] def process(self, p: processing.StableDiffusionProcessing, *args): # pylint: disable=arguments-differ if not shared.native: @@ -94,6 +95,7 @@ def process(self, p: processing.StableDiffusionProcessing, *args): # pylint: dis if len(args) == 0: return units = args.pop(0) + unload = args.pop(0) if getattr(p, 'ip_adapter_names', []) == []: p.ip_adapter_names = args[:MAX_ADAPTERS][:units] if getattr(p, 'ip_adapter_scales', [0.0]) == [0.0]: @@ -110,6 +112,7 @@ def process(self, p: processing.StableDiffusionProcessing, *args): # pylint: dis p.ip_adapter_masks = args[MAX_ADAPTERS*6:MAX_ADAPTERS*7][:units] p.ip_adapter_masks = [x for x in p.ip_adapter_masks if x] layers_active, layers = args[MAX_ADAPTERS*7:MAX_ADAPTERS*8] + p.ip_adapter_unload = unload if layers_active and len(layers) > 0: try: layers = json.loads(layers) diff --git a/scripts/ltxvideo.py b/scripts/ltxvideo.py new file mode 100644 index 000000000..007c4f4cc --- /dev/null +++ b/scripts/ltxvideo.py @@ -0,0 +1,148 @@ +import os +import time +import torch +import gradio as gr +import diffusers +import transformers +from modules import scripts, processing, shared, images, devices, sd_models, sd_checkpoint, model_quant + + +repos = { + '0.9.0': 'a-r-r-o-w/LTX-Video-diffusers', + '0.9.1': 'a-r-r-o-w/LTX-Video-0.9.1-diffusers', + 'custom': None, +} + + +def load_quants(kwargs, repo_id): + if len(shared.opts.bnb_quantization) > 0: + quant_args = {} + quant_args = model_quant.create_bnb_config(quant_args) + quant_args = 
model_quant.create_ao_config(quant_args) + if not quant_args: + return kwargs + model_quant.load_bnb(f'Load model: type=LTX quant={quant_args}') + if 'Model' in shared.opts.bnb_quantization and 'transformer' not in kwargs: + kwargs['transformer'] = diffusers.LTXVideoTransformer3DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=shared.opts.hfcache_dir, torch_dtype=devices.dtype, **quant_args) + shared.log.debug(f'Quantization: module=transformer type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}') + if 'Text Encoder' in shared.opts.bnb_quantization and 'text_encoder_3' not in kwargs: + kwargs['text_encoder'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder", cache_dir=shared.opts.hfcache_dir, torch_dtype=devices.dtype, **quant_args) + shared.log.debug(f'Quantization: module=t5 type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}') + return kwargs + + +def hijack_decode(*args, **kwargs): + shared.log.debug('Video: decode') + shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) + return shared.sd_model.vae.orig_decode(*args, **kwargs) + + +class Script(scripts.Script): + def title(self): + return 'Video: LTX Video' + + def show(self, is_img2img): + return shared.native + + # return signature is array of gradio components + def ui(self, _is_img2img): + def video_type_change(video_type): + return [ + gr.update(visible=video_type != 'None'), + gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), + gr.update(visible=video_type == 'MP4'), + gr.update(visible=video_type == 'MP4'), + ] + def model_change(model): + return gr.update(visible=model == 'custom') + + with gr.Row(): + gr.HTML('  LTX Video
') + with gr.Row(): + model = gr.Dropdown(label='LTX Model', choices=list(repos), value='0.9.1') + decode = gr.Dropdown(label='Decode', choices=['diffusers', 'native'], value='diffusers', visible=False) + with gr.Row(): + num_frames = gr.Slider(label='Frames', minimum=9, maximum=257, step=1, value=41) + sampler = gr.Checkbox(label='Override sampler', value=True) + with gr.Row(): + model_custom = gr.Textbox(value='', label='Path to model file', visible=False) + with gr.Row(): + video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') + duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False) + with gr.Row(): + gif_loop = gr.Checkbox(label='Loop', value=True, visible=False) + mp4_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) + mp4_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) + video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, gif_loop, mp4_pad, mp4_interpolate]) + model.change(fn=model_change, inputs=[model], outputs=[model_custom]) + return [model, model_custom, decode, sampler, num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate] + + def run(self, p: processing.StableDiffusionProcessing, model, model_custom, decode, sampler, num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate): # pylint: disable=arguments-differ, unused-argument + # set params + image = getattr(p, 'init_images', None) + image = None if image is None or len(image) == 0 else image[0] + if (p.width == 0 or p.height == 0) and image is not None: + p.width = image.width + p.height = image.height + num_frames = 8 * int(num_frames // 8) + 1 + p.width = 32 * int(p.width // 32) + p.height = 32 * int(p.height // 32) + processing.fix_seed(p) + if image: + image = images.resize_image(resize_mode=2, im=image, width=p.width, height=p.height, upscaler_name=None, output_type='pil') + p.task_args['image'] = image + p.task_args['output_type'] = 'latent' if decode == 'native' else 'pil' + p.task_args['generator'] = torch.Generator(devices.device).manual_seed(p.seed) + p.task_args['num_frames'] = num_frames + p.do_not_save_grid = True + if sampler: + p.sampler_name = 'Default' + p.ops.append('video') + + # load model + cls = diffusers.LTXPipeline if image is None else diffusers.LTXImageToVideoPipeline + diffusers.LTXTransformer3DModel = diffusers.LTXVideoTransformer3DModel + diffusers.AutoencoderKLLTX = diffusers.AutoencoderKLLTXVideo + repo_id = repos[model] + if repo_id is None: + repo_id = model_custom + if shared.sd_model.__class__ != cls: + sd_models.unload_model_weights() + kwargs = {} + kwargs = model_quant.create_bnb_config(kwargs) + kwargs = model_quant.create_ao_config(kwargs) + if os.path.isfile(repo_id): + shared.sd_model = cls.from_single_file( + repo_id, + cache_dir = shared.opts.hfcache_dir, + torch_dtype=devices.dtype, + **kwargs + ) + else: + kwargs = load_quants(kwargs, repo_id) + shared.sd_model = cls.from_pretrained( + repo_id, + cache_dir = shared.opts.hfcache_dir, + torch_dtype=devices.dtype, + **kwargs + ) + sd_models.set_diffuser_options(shared.sd_model) + shared.sd_model.vae.orig_decode = shared.sd_model.vae.decode + shared.sd_model.vae.decode = hijack_decode + shared.sd_model.sd_checkpoint_info = sd_checkpoint.CheckpointInfo(repo_id) + shared.sd_model.sd_model_hash = None + shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) + 
shared.sd_model.vae.enable_slicing() + shared.sd_model.vae.enable_tiling() + devices.torch_gc(force=True) + shared.log.debug(f'Video: cls={shared.sd_model.__class__.__name__} args={p.task_args}') + + # run processing + t0 = time.time() + processed = processing.process_images(p) + t1 = time.time() + if processed is not None and len(processed.images) > 0: + shared.log.info(f'Video: frames={len(processed.images)} time={t1-t0:.2f}') + if video_type != 'None': + images.save_video(p, filename=None, images=processed.images, video_type=video_type, duration=duration, loop=gif_loop, pad=mp4_pad, interpolate=mp4_interpolate) + return processed diff --git a/scripts/mochivideo.py b/scripts/mochivideo.py new file mode 100644 index 000000000..f85616a5e --- /dev/null +++ b/scripts/mochivideo.py @@ -0,0 +1,85 @@ +import time +import torch +import gradio as gr +import diffusers +from modules import scripts, processing, shared, images, devices, sd_models, sd_checkpoint, model_quant + + +repo_id = 'genmo/mochi-1-preview' + + +class Script(scripts.Script): + def title(self): + return 'Video: Mochi.1 Video' + + def show(self, is_img2img): + return not is_img2img if shared.native else False + + # return signature is array of gradio components + def ui(self, _is_img2img): + def video_type_change(video_type): + return [ + gr.update(visible=video_type != 'None'), + gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), + gr.update(visible=video_type == 'MP4'), + gr.update(visible=video_type == 'MP4'), + ] + + with gr.Row(): + gr.HTML('  Mochi.1 Video
') + with gr.Row(): + num_frames = gr.Slider(label='Frames', minimum=9, maximum=257, step=1, value=45) + with gr.Row(): + video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None') + duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False) + with gr.Row(): + gif_loop = gr.Checkbox(label='Loop', value=True, visible=False) + mp4_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False) + mp4_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False) + video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, gif_loop, mp4_pad, mp4_interpolate]) + return [num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate] + + def run(self, p: processing.StableDiffusionProcessing, num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate): # pylint: disable=arguments-differ, unused-argument + # set params + num_frames = int(num_frames // 8) + p.width = 32 * int(p.width // 32) + p.height = 32 * int(p.height // 32) + p.task_args['output_type'] = 'pil' + p.task_args['generator'] = torch.manual_seed(p.seed) + p.task_args['num_frames'] = num_frames + p.sampler_name = 'Default' + p.do_not_save_grid = True + p.ops.append('video') + + # load model + cls = diffusers.MochiPipeline + if shared.sd_model.__class__ != cls: + sd_models.unload_model_weights() + kwargs = {} + kwargs = model_quant.create_bnb_config(kwargs) + kwargs = model_quant.create_ao_config(kwargs) + shared.sd_model = cls.from_pretrained( + repo_id, + cache_dir = shared.opts.hfcache_dir, + torch_dtype=devices.dtype, + **kwargs + ) + shared.sd_model.scheduler._shift = 7.0 # pylint: disable=protected-access + sd_models.set_diffuser_options(shared.sd_model) + shared.sd_model.sd_checkpoint_info = sd_checkpoint.CheckpointInfo(repo_id) + shared.sd_model.sd_model_hash = None + shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model) + shared.sd_model.vae.enable_slicing() + shared.sd_model.vae.enable_tiling() + devices.torch_gc(force=True) + shared.log.debug(f'Video: cls={shared.sd_model.__class__.__name__} args={p.task_args}') + + # run processing + t0 = time.time() + processed = processing.process_images(p) + t1 = time.time() + if processed is not None and len(processed.images) > 0: + shared.log.info(f'Video: frames={len(processed.images)} time={t1-t0:.2f}') + if video_type != 'None': + images.save_video(p, filename=None, images=processed.images, video_type=video_type, duration=duration, loop=gif_loop, pad=mp4_pad, interpolate=mp4_interpolate) + return processed diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py index 3e5fd451d..8ed677736 100644 --- a/scripts/postprocessing_codeformer.py +++ b/scripts/postprocessing_codeformer.py @@ -10,9 +10,10 @@ class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing order = 3000 def ui(self): - with gr.Row(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Strength", value=0.0, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Weight", value=0.2, elem_id="extras_codeformer_weight") + with gr.Accordion('Restore faces: CodeFormer', open = False): + with gr.Row(): + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Strength", value=0.0, elem_id="extras_codeformer_visibility") + codeformer_weight = gr.Slider(minimum=0.0, 
maximum=1.0, step=0.01, label="Weight", value=0.2, elem_id="extras_codeformer_weight") return { "codeformer_visibility": codeformer_visibility, "codeformer_weight": codeformer_weight } def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight): # pylint: disable=arguments-differ diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py index a69f97c9e..1e17c2b16 100644 --- a/scripts/postprocessing_gfpgan.py +++ b/scripts/postprocessing_gfpgan.py @@ -9,8 +9,9 @@ class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing): order = 2000 def ui(self): - with gr.Row(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Strength", value=0, elem_id="extras_gfpgan_visibility") + with gr.Accordion('Restore faces: GFPGan', open = False): + with gr.Row(): + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Strength", value=0, elem_id="extras_gfpgan_visibility") return { "gfpgan_visibility": gfpgan_visibility } def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility): # pylint: disable=arguments-differ diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index 0e07bc847..104a0fb37 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -10,43 +10,43 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): order = 1000 def ui(self): - selected_tab = gr.State(value=0) # pylint: disable=abstract-class-instantiated - - with gr.Column(): - with gr.Row(elem_id="extras_upscale"): - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: - upscaling_resize = gr.Slider(minimum=0.1, maximum=8.0, step=0.05, label="Resize", value=2.0, elem_id="extras_upscaling_resize") - - with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: - with gr.Row(): - with gr.Row(elem_id="upscaling_column_size"): - upscaling_resize_w = gr.Slider(minimum=64, maximum=4096, step=8, label="Width", value=1024, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Slider(minimum=64, maximum=4096, step=8, label="Height", value=1024, elem_id="extras_upscaling_resize_h") - upscaling_res_switch_btn = ToolButton(value=symbols.switch, elem_id="upscaling_res_switch_btn") - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") - - with gr.Row(): - extras_upscaler_1 = gr.Dropdown(label='Upscaler', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) - - with gr.Row(): - extras_upscaler_2 = gr.Dropdown(label='Refine Upscaler', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility") - - upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False) - tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab]) - tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab]) - - return { - "upscale_mode": selected_tab, - "upscale_by": upscaling_resize, - "upscale_to_width": upscaling_resize_w, - "upscale_to_height": upscaling_resize_h, - "upscale_crop": upscaling_crop, - 
"upscaler_1_name": extras_upscaler_1, - "upscaler_2_name": extras_upscaler_2, - "upscaler_2_visibility": extras_upscaler_2_visibility, - } + with gr.Accordion('Postprocess Upscale', open = False): + selected_tab = gr.State(value=0) # pylint: disable=abstract-class-instantiated + with gr.Column(): + with gr.Row(elem_id="extras_upscale"): + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: + upscaling_resize = gr.Slider(minimum=0.1, maximum=8.0, step=0.05, label="Resize", value=2.0, elem_id="extras_upscaling_resize") + + with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: + with gr.Row(): + with gr.Row(elem_id="upscaling_column_size"): + upscaling_resize_w = gr.Slider(minimum=64, maximum=4096, step=8, label="Width", value=1024, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Slider(minimum=64, maximum=4096, step=8, label="Height", value=1024, elem_id="extras_upscaling_resize_h") + upscaling_res_switch_btn = ToolButton(value=symbols.switch, elem_id="upscaling_res_switch_btn") + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") + + with gr.Row(): + extras_upscaler_1 = gr.Dropdown(label='Upscaler', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + + with gr.Row(): + extras_upscaler_2 = gr.Dropdown(label='Refine Upscaler', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility") + + upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False) + tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab]) + tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab]) + + return { + "upscale_mode": selected_tab, + "upscale_by": upscaling_resize, + "upscale_to_width": upscaling_resize_w, + "upscale_to_height": upscaling_resize_h, + "upscale_crop": upscaling_crop, + "upscaler_1_name": extras_upscaler_1, + "upscaler_2_name": extras_upscaler_2, + "upscaler_2_visibility": extras_upscaler_2_visibility, + } def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop): if upscale_mode == 1: diff --git a/scripts/postprocessing_video.py b/scripts/postprocessing_video.py index 8c266639c..75d375572 100644 --- a/scripts/postprocessing_video.py +++ b/scripts/postprocessing_video.py @@ -7,40 +7,41 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): name = "Video" def ui(self): - def video_type_change(video_type): - return [ - gr.update(visible=video_type != 'None'), - gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), - gr.update(visible=video_type == 'MP4'), - gr.update(visible=video_type == 'MP4'), - gr.update(visible=video_type == 'MP4'), - gr.update(visible=video_type == 'MP4'), - ] + with gr.Accordion('Create video', open = False): + def video_type_change(video_type): + return [ + gr.update(visible=video_type != 'None'), + gr.update(visible=video_type == 'GIF' or video_type == 'PNG'), + gr.update(visible=video_type == 'MP4'), + gr.update(visible=video_type == 'MP4'), + gr.update(visible=video_type == 'MP4'), + gr.update(visible=video_type == 'MP4'), + ] - 
with gr.Row(): - gr.HTML("  Video
") - with gr.Row(): - video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None', elem_id="extras_video_type") - duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False, elem_id="extras_video_duration") - with gr.Row(): - loop = gr.Checkbox(label='Loop', value=True, visible=False, elem_id="extras_video_loop") - pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False, elem_id="extras_video_pad") - interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False, elem_id="extras_video_interpolate") - scale = gr.Slider(label='Rescale', minimum=0.5, maximum=2, step=0.05, value=1, visible=False, elem_id="extras_video_scale") - change = gr.Slider(label='Frame change sensitivity', minimum=0, maximum=1, step=0.05, value=0.3, visible=False, elem_id="extras_video_change") - with gr.Row(): - filename = gr.Textbox(label='Filename', placeholder='enter filename', lines=1, elem_id="extras_video_filename") - video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, loop, pad, interpolate, scale, change]) - return { - "filename": filename, - "video_type": video_type, - "duration": duration, - "loop": loop, - "pad": pad, - "interpolate": interpolate, - "scale": scale, - "change": change, - } + with gr.Row(): + gr.HTML("  Video
") + with gr.Row(): + video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None', elem_id="extras_video_type") + duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False, elem_id="extras_video_duration") + with gr.Row(): + loop = gr.Checkbox(label='Loop', value=True, visible=False, elem_id="extras_video_loop") + pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False, elem_id="extras_video_pad") + interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False, elem_id="extras_video_interpolate") + scale = gr.Slider(label='Rescale', minimum=0.5, maximum=2, step=0.05, value=1, visible=False, elem_id="extras_video_scale") + change = gr.Slider(label='Frame change sensitivity', minimum=0, maximum=1, step=0.05, value=0.3, visible=False, elem_id="extras_video_change") + with gr.Row(): + filename = gr.Textbox(label='Filename', placeholder='enter filename', lines=1, elem_id="extras_video_filename") + video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, loop, pad, interpolate, scale, change]) + return { + "filename": filename, + "video_type": video_type, + "duration": duration, + "loop": loop, + "pad": pad, + "interpolate": interpolate, + "scale": scale, + "change": change, + } def postprocess(self, images, filename, video_type, duration, loop, pad, interpolate, scale, change): # pylint: disable=arguments-differ filename = filename.strip() if filename is not None else '' diff --git a/scripts/pulid_ext.py b/scripts/pulid_ext.py index 676fa79f3..ee08e348b 100644 --- a/scripts/pulid_ext.py +++ b/scripts/pulid_ext.py @@ -164,11 +164,13 @@ def run( p.batch_size = 1 sdp = shared.opts.cross_attention_optimization == "Scaled-Dot-Product" + sampler_fn = getattr(self.pulid.sampling, f'sample_{sampler}', None) strength = getattr(p, 'pulid_strength', strength) zero = getattr(p, 'pulid_zero', zero) ortho = getattr(p, 'pulid_ortho', ortho) sampler = getattr(p, 'pulid_sampler', sampler) - sampler_fn = getattr(self.pulid.sampling, f'sample_{sampler}', None) + restore = getattr(p, 'pulid_restore', restore) + p.pulid_restore = restore if sampler_fn is None: sampler_fn = self.pulid.sampling.sample_dpmpp_2m_sde @@ -199,7 +201,7 @@ def run( return None shared.sd_model.sampler = sampler_fn - shared.log.info(f'PuLID: class={shared.sd_model.__class__.__name__} version="{version}" sdp={sdp} strength={strength} zero={zero} ortho={ortho} sampler={sampler_fn} images={[i.shape for i in images]} offload={offload}') + shared.log.info(f'PuLID: class={shared.sd_model.__class__.__name__} version="{version}" sdp={sdp} strength={strength} zero={zero} ortho={ortho} sampler={sampler_fn} images={[i.shape for i in images]} offload={offload} restore={restore}') self.pulid.attention.NUM_ZERO = zero self.pulid.attention.ORTHO = ortho == 'v1' self.pulid.attention.ORTHO_v2 = ortho == 'v2' diff --git a/scripts/regional_prompting.py b/scripts/regional_prompting.py index 08b84dd94..3d452b0c5 100644 --- a/scripts/regional_prompting.py +++ b/scripts/regional_prompting.py @@ -9,7 +9,7 @@ def hijack_register_modules(self, **kwargs): for name, module in kwargs.items(): register_dict = None - if module is None or isinstance(module, (tuple, list)) and module[0] is None: + if module is None or (isinstance(module, (tuple, list)) and module[0] is None): register_dict = {name: (None, None)} elif isinstance(module, bool): pass @@ -82,6 +82,7 @@ def run(self, p: 
processing.StableDiffusionProcessing, mode, grid, power, thresh } # run pipeline shared.log.debug(f'Regional: args={p.task_args}') + p.task_args['prompt'] = p.prompt processed: processing.Processed = processing.process_images(p) # runs processing using main loop # restore pipeline and params diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index 9c5a72204..7ac31b603 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py @@ -48,7 +48,7 @@ def run(self, p, _, overlap, upscaler_index, scale_factor): # pylint: disable=ar else: img = init_img devices.torch_gc() - grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap) + grid = images.split_grid(img, tile_w=init_img.width, tile_h=init_img.height, overlap=overlap) batch_size = p.batch_size upscale_count = p.n_iter p.n_iter = 1 @@ -61,7 +61,7 @@ def run(self, p, _, overlap, upscaler_index, scale_factor): # pylint: disable=ar batch_count = math.ceil(len(work) / batch_size) state.job_count = batch_count * upscale_count - log.info(f"SD upscale: images={len(work)} tile={len(grid.tiles[0][2])}x{len(grid.tiles)} batches={state.job_count}") + log.info(f"SD upscale: images={len(work)} tiles={len(grid.tiles)} batches={state.job_count}") result_images = [] for n in range(upscale_count): @@ -91,4 +91,5 @@ def run(self, p, _, overlap, upscaler_index, scale_factor): # pylint: disable=ar images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p) processed = Processed(p, result_images, seed, initial_info) + log.info(f"SD upscale: images={result_images}") return processed diff --git a/scripts/stablevideodiffusion.py b/scripts/stablevideodiffusion.py index cbf2ce003..c1283e1b6 100644 --- a/scripts/stablevideodiffusion.py +++ b/scripts/stablevideodiffusion.py @@ -16,7 +16,7 @@ class Script(scripts.Script): def title(self): - return 'Video: SVD' + return 'Video: Stable Video Diffusion' def show(self, is_img2img): return is_img2img if shared.native else False @@ -75,13 +75,17 @@ def run(self, p: processing.StableDiffusionProcessing, model, num_frames, overri if model_name != model_loaded or c != 'StableVideoDiffusionPipeline': shared.opts.sd_model_checkpoint = model_path sd_models.reload_model_weights() + shared.sd_model = shared.sd_model.to(torch.float32) # TODO svd: runs in fp32 causing dtype mismatch # set params if override_resolution: p.width = 1024 p.height = 576 image = images.resize_image(resize_mode=2, im=image, width=p.width, height=p.height, upscaler_name=None, output_type='pil') - p.ops.append('svd') + else: + p.width = image.width + p.height = image.height + p.ops.append('video') p.do_not_save_grid = True p.init_images = [image] p.sampler_name = 'Default' # svd does not support non-default sampler diff --git a/scripts/style_aligned.py b/scripts/style_aligned.py new file mode 100644 index 000000000..25feb49bc --- /dev/null +++ b/scripts/style_aligned.py @@ -0,0 +1,117 @@ +import gradio as gr +import torch +import numpy as np +import diffusers +from modules import scripts, processing, shared, devices + + +handler = None +zts = None +supported_model_list = ['sdxl'] +orig_prompt_attention = None + + +class Script(scripts.Script): + def title(self): + return 'Style Aligned Image Generation' + + def show(self, is_img2img): + return shared.native + + def reset(self): + global handler, zts # pylint: disable=global-statement + handler = None + zts = None + shared.log.info('SA: image upload') + + def preset(self, preset): + if preset == 'text': + return 
[['attention', 'adain_queries', 'adain_keys'], 1.0, 0, 0.0] + if preset == 'image': + return [['group_norm', 'layer_norm', 'attention', 'adain_queries', 'adain_keys'], 1.0, 2, 0.0] + if preset == 'all': + return [['group_norm', 'layer_norm', 'attention', 'adain_queries', 'adain_keys', 'adain_values', 'full_attention_share'], 1.0, 1, 0.5] + + def ui(self, _is_img2img): # ui elements + with gr.Row(): + gr.HTML('  Style Aligned Image Generation

') + with gr.Row(): + preset = gr.Dropdown(label="Preset", choices=['text', 'image', 'all'], value='text') + scheduler = gr.Checkbox(label="Override scheduler", value=False) + with gr.Row(): + shared_opts = gr.Dropdown(label="Shared options", + multiselect=True, + choices=['group_norm', 'layer_norm', 'attention', 'adain_queries', 'adain_keys', 'adain_values', 'full_attention_share'], + value=['attention', 'adain_queries', 'adain_keys'], + ) + with gr.Row(): + shared_score_scale = gr.Slider(label="Scale", minimum=0.0, maximum=2.0, step=0.01, value=1.0) + shared_score_shift = gr.Slider(label="Shift", minimum=0, maximum=10, step=1, value=0) + only_self_level = gr.Slider(label="Level", minimum=0.0, maximum=1.0, step=0.01, value=0.0) + with gr.Row(): + prompt = gr.Textbox(lines=1, label='Optional image description', placeholder='use the style from the image') + with gr.Row(): + image = gr.Image(label='Optional image', source='upload', type='pil') + + image.change(self.reset) + preset.change(self.preset, inputs=[preset], outputs=[shared_opts, shared_score_scale, shared_score_shift, only_self_level]) + + return [image, prompt, scheduler, shared_opts, shared_score_scale, shared_score_shift, only_self_level] + + def run(self, p: processing.StableDiffusionProcessing, image, prompt, scheduler, shared_opts, shared_score_scale, shared_score_shift, only_self_level): # pylint: disable=arguments-differ + global handler, zts, orig_prompt_attention # pylint: disable=global-statement + if shared.sd_model_type not in supported_model_list: + shared.log.warning(f'SA: class={shared.sd_model.__class__.__name__} model={shared.sd_model_type} required={supported_model_list}') + return None + + from modules.style_aligned import sa_handler, inversion + + handler = sa_handler.Handler(shared.sd_model) + sa_args = sa_handler.StyleAlignedArgs( + share_group_norm='group_norm' in shared_opts, + share_layer_norm='layer_norm' in shared_opts, + share_attention='attention' in shared_opts, + adain_queries='adain_queries' in shared_opts, + adain_keys='adain_keys' in shared_opts, + adain_values='adain_values' in shared_opts, + full_attention_share='full_attention_share' in shared_opts, + shared_score_scale=float(shared_score_scale), + shared_score_shift=np.log(shared_score_shift) if shared_score_shift > 0 else 0, + only_self_level=1 if only_self_level else 0, + ) + handler.register(sa_args) + + if scheduler: + shared.sd_model.scheduler = diffusers.DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) + p.sampler_name = 'None' + + if image is not None and zts is None: + shared.log.info(f'SA: inversion image={image} prompt="{prompt}"') + image = image.resize((1024, 1024)) + x0 = np.array(image).astype(np.float32) / 255.0 + shared.sd_model.scheduler = diffusers.DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) + zts = inversion.ddim_inversion(shared.sd_model, x0, prompt, num_inference_steps=50, guidance_scale=2) + + p.prompt = p.prompt.splitlines() + p.batch_size = len(p.prompt) + orig_prompt_attention = shared.opts.prompt_attention + shared.opts.data['prompt_attention'] = 'fixed' # otherwise need to deal with class_tokens_mask + + if zts is not None: + processing.fix_seed(p) + zT, inversion_callback = inversion.make_inversion_callback(zts, offset=0) + generator = torch.Generator(device='cpu') + generator.manual_seed(p.seed) + latents = torch.randn(p.batch_size, 4, 128, 128, device='cpu', 
+            latents[0] = zT
+            p.task_args['latents'] = latents
+            p.task_args['callback_on_step_end'] = inversion_callback
+
+        shared.log.info(f'SA: batch={p.batch_size} type={"image" if zts is not None else "text"} config={sa_args.__dict__}')
+
+    def after(self, p: processing.StableDiffusionProcessing, *args): # pylint: disable=unused-argument
+        global handler # pylint: disable=global-statement
+        if handler is not None:
+            handler.remove()
+            handler = None
+            shared.opts.data['prompt_attention'] = orig_prompt_attention
diff --git a/scripts/text2video.py b/scripts/text2video.py
index dc4c44cac..c7b3d1c05 100644
--- a/scripts/text2video.py
+++ b/scripts/text2video.py
@@ -87,7 +87,7 @@ def run(self, p: processing.StableDiffusionProcessing, model_name, use_default,
             shared.opts.sd_model_checkpoint = checkpoint.name
            sd_models.reload_model_weights(op='model')
 
-        p.ops.append('text2video')
+        p.ops.append('video')
         p.do_not_save_grid = True
         if use_default:
             p.task_args['num_frames'] = model['params'][0]
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 60e608c76..bb067ea21 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -12,6 +12,7 @@
 from scripts.xyz_grid_shared import str_permutations, list_to_csv_string, re_range # pylint: disable=no-name-in-module
 from scripts.xyz_grid_classes import axis_options, AxisOption, SharedSettingsStackHelper # pylint: disable=no-name-in-module
 from scripts.xyz_grid_draw import draw_xyz_grid # pylint: disable=no-name-in-module
+from scripts.xyz_grid_shared import apply_field, apply_task_args, apply_setting, apply_prompt, apply_order, apply_sampler, apply_hr_sampler_name, confirm_samplers, apply_checkpoint, apply_refiner, apply_unet, apply_dict, apply_clip_skip, apply_vae, list_lora, apply_lora, apply_lora_strength, apply_te, apply_styles, apply_upscaler, apply_context, apply_detailer, apply_override, apply_processing, apply_options, apply_seed, format_value_add_label, format_value, format_value_join_list, do_nothing, format_nothing # pylint: disable=no-name-in-module, unused-import
 from modules import shared, errors, scripts, images, processing
 from modules.ui_components import ToolButton
 import modules.ui_symbols as symbols
diff --git a/scripts/xyz_grid_classes.py b/scripts/xyz_grid_classes.py
index b80b9f13c..cc70d68f8 100644
--- a/scripts/xyz_grid_classes.py
+++ b/scripts/xyz_grid_classes.py
@@ -1,4 +1,4 @@
-from scripts.xyz_grid_shared import apply_field, apply_task_args, apply_setting, apply_prompt, apply_order, apply_sampler, apply_hr_sampler_name, confirm_samplers, apply_checkpoint, apply_refiner, apply_unet, apply_dict, apply_clip_skip, apply_vae, list_lora, apply_lora, apply_te, apply_styles, apply_upscaler, apply_context, apply_detailer, apply_override, apply_processing, apply_options, apply_seed, format_value_add_label, format_value, format_value_join_list, do_nothing, format_nothing, str_permutations # pylint: disable=no-name-in-module, unused-import
+from scripts.xyz_grid_shared import apply_field, apply_task_args, apply_setting, apply_prompt, apply_order, apply_sampler, apply_hr_sampler_name, confirm_samplers, apply_checkpoint, apply_refiner, apply_unet, apply_dict, apply_clip_skip, apply_vae, list_lora, apply_lora, apply_lora_strength, apply_te, apply_styles, apply_upscaler, apply_context, apply_detailer, apply_override, apply_processing, apply_options, apply_seed, format_value_add_label, format_value, format_value_join_list, do_nothing, format_nothing, str_permutations # pylint: disable=no-name-in-module, unused-import
 from modules import shared, shared_items, sd_samplers, ipadapter, sd_models, sd_vae, sd_unet
@@ -97,7 +97,7 @@ def __exit__(self, exc_type, exc_value, tb):
     AxisOption("[Prompt] Prompt order", str_permutations, apply_order, fmt=format_value_join_list),
     AxisOption("[Prompt] Prompt parser", str, apply_setting("prompt_attention"), choices=lambda: ["native", "compel", "xhinker", "a1111", "fixed"]),
     AxisOption("[Network] LoRA", str, apply_lora, cost=0.5, choices=list_lora),
-    AxisOption("[Network] LoRA strength", float, apply_setting('extra_networks_default_multiplier')),
+    AxisOption("[Network] LoRA strength", float, apply_lora_strength, cost=0.6),
     AxisOption("[Network] Styles", str, apply_styles, choices=lambda: [s.name for s in shared.prompt_styles.styles.values()]),
     AxisOption("[Param] Width", int, apply_field("width")),
     AxisOption("[Param] Height", int, apply_field("height")),
diff --git a/scripts/xyz_grid_on.py b/scripts/xyz_grid_on.py
index 202a2cfc4..aa0897442 100644
--- a/scripts/xyz_grid_on.py
+++ b/scripts/xyz_grid_on.py
@@ -413,6 +413,7 @@ def cell(x, y, z, ix, iy, iz):
                 p.do_not_save_grid = True
                 p.do_not_save_samples = True
+                p.disable_extra_networks = True
                 active = False
         cache = processed
         return processed
diff --git a/scripts/xyz_grid_shared.py b/scripts/xyz_grid_shared.py
index d3ee0a864..f9bf26c67 100644
--- a/scripts/xyz_grid_shared.py
+++ b/scripts/xyz_grid_shared.py
@@ -63,28 +63,15 @@ def apply_seed(p, x, xs):
 
 
 def apply_prompt(p, x, xs):
-    if not hasattr(p, 'orig_prompt'):
-        p.orig_prompt = p.prompt
-        p.orig_negative = p.negative_prompt
-    if xs[0] not in p.orig_prompt and xs[0] not in p.orig_negative:
-        shared.log.warning(f'XYZ grid: prompt S/R string="{xs[0]}" not found')
-    else:
-        p.prompt = p.orig_prompt.replace(xs[0], x)
-        p.negative_prompt = p.orig_negative.replace(xs[0], x)
-        p.all_prompts = None
-        p.all_negative_prompts = None
-    """
-    if p.all_prompts is not None:
-        for i in range(len(p.all_prompts)):
-            for j in range(len(xs)):
-                p.all_prompts[i] = p.all_prompts[i].replace(xs[j], x)
-    p.negative_prompt = p.negative_prompt.replace(xs[0], x)
-    if p.all_negative_prompts is not None:
-        for i in range(len(p.all_negative_prompts)):
-            for j in range(len(xs)):
-                p.all_negative_prompts[i] = p.all_negative_prompts[i].replace(xs[j], x)
-    """
-    shared.log.debug(f'XYZ grid apply prompt: "{xs[0]}"="{x}"')
+    for s in xs:
+        if s in p.prompt:
+            shared.log.debug(f'XYZ grid apply prompt: "{s}"="{x}"')
+            p.prompt = p.prompt.replace(s, x)
+        if s in p.negative_prompt:
+            shared.log.debug(f'XYZ grid apply negative: "{s}"="{x}"')
+            p.negative_prompt = p.negative_prompt.replace(s, x)
+    p.all_prompts = None
+    p.all_negative_prompts = None
 
 
 def apply_order(p, x, xs):
@@ -205,19 +192,28 @@ def apply_vae(p, x, xs):
 
 def list_lora():
     import sys
-    lora = [v for k, v in sys.modules.items() if k == 'networks'][0]
+    lora = [v for k, v in sys.modules.items() if k == 'networks' or k == 'modules.lora.networks'][0]
     loras = [v.fullname for v in lora.available_networks.values()]
     return ['None'] + loras
 
 
 def apply_lora(p, x, xs):
+    p.all_prompts = None
+    p.all_negative_prompts = None
     if x == 'None':
         return
     x = os.path.basename(x)
     p.prompt = p.prompt + f" <lora:{x}:{shared.opts.extra_networks_default_multiplier}>"
+    shared.log.debug(f'XYZ grid apply LoRA: "{x}"')
+
+
+def apply_lora_strength(p, x, xs):
+    shared.log.debug(f'XYZ grid apply LoRA strength: "{x}"')
+    p.prompt = p.prompt.replace(':1.0>', '>')
+    p.prompt = p.prompt.replace(f':{shared.opts.extra_networks_default_multiplier}>', '>')
     p.all_prompts = None
     p.all_negative_prompts = None
-    shared.log.debug(f'XYZ grid apply LoRA: "{x}"')
+    shared.opts.data['extra_networks_default_multiplier'] = x
 
 
 def apply_te(p, x, xs):
diff --git a/webui.py b/webui.py
index 9684ef7c8..4eb6e89ce 100644
--- a/webui.py
+++ b/webui.py
@@ -8,6 +8,7 @@ import importlib
 import contextlib
 from threading import Thread
+import modules.hashes
 import modules.loader
 import torch # pylint: disable=wrong-import-order
 from modules import timer, errors, paths # pylint: disable=unused-import
@@ -18,6 +19,7 @@ from modules.paths import create_paths
 from modules.call_queue import queue_lock, wrap_queued_call, wrap_gradio_gpu_call # pylint: disable=unused-import
 import modules.devices
+import modules.sd_checkpoint
 import modules.sd_samplers
 import modules.lowvram
 import modules.scripts
@@ -77,6 +79,9 @@ def check_rollback_vae():
 
 def initialize():
     log.debug('Initializing')
+
+    modules.sd_checkpoint.init_metadata()
+    modules.hashes.init_cache()
     check_rollback_vae()
 
     modules.sd_samplers.list_samplers()
@@ -89,15 +94,20 @@ def initialize():
     timer.startup.record("unet")
 
     modules.model_te.refresh_te_list()
-    timer.startup.record("unet")
-
-    extensions.list_extensions()
-    timer.startup.record("extensions")
+    timer.startup.record("te")
 
     modelloader.cleanup_models()
     modules.sd_models.setup_model()
     timer.startup.record("models")
 
+    if shared.native:
+        import modules.lora.networks as lora_networks
+        lora_networks.list_available_networks()
+        timer.startup.record("lora")
+
+    shared.prompt_styles.reload()
+    timer.startup.record("styles")
+
     import modules.postprocess.codeformer_model as codeformer
     codeformer.setup_model(shared.opts.codeformer_models_path)
     sys.modules["modules.codeformer_model"] = codeformer
@@ -107,6 +117,9 @@ def initialize():
     yolo.initialize()
     timer.startup.record("detailer")
 
+    extensions.list_extensions()
+    timer.startup.record("extensions")
+
     log.info('Load extensions')
     t_timer, t_total = modules.scripts.load_scripts()
     timer.startup.record("extensions")
@@ -116,8 +129,9 @@ def initialize():
     modelloader.load_upscalers()
     timer.startup.record("upscalers")
 
-    shared.reload_hypernetworks()
-    shared.prompt_styles.reload()
+    if shared.opts.hypernetwork_enabled:
+        shared.reload_hypernetworks()
+        timer.startup.record("hypernetworks")
 
     ui_extra_networks.initialize()
     ui_extra_networks.register_pages()
diff --git a/wiki b/wiki
index 30f3265bb..56ba782f7 160000
--- a/wiki
+++ b/wiki
@@ -1 +1 @@
-Subproject commit 30f3265bb06ac738e4467f58be4df3fc4b49c08b
+Subproject commit 56ba782f744bb8f6928f6c365d6ffc547d339548