diff --git a/.eslintrc.json b/.eslintrc.json
index 62feb13a5..c86dbb749 100644
--- a/.eslintrc.json
+++ b/.eslintrc.json
@@ -37,14 +37,19 @@
"object-curly-newline":"off",
"prefer-rest-params":"off",
"prefer-destructuring":"off",
- "radix":"off"
+ "radix":"off",
+ "node/shebang": "off"
},
"globals": {
// asssets
"panzoom": "readonly",
- // script.js
+ // logger.js
"log": "readonly",
"debug": "readonly",
+ "error": "readonly",
+ "xhrGet": "readonly",
+ "xhrPost": "readonly",
+ // script.js
"gradioApp": "readonly",
"executeCallbacks": "readonly",
"onAfterUiUpdate": "readonly",
@@ -87,7 +92,6 @@
// settings.js
"registerDragDrop": "readonly",
// extraNetworks.js
- "requestGet": "readonly",
"getENActiveTab": "readonly",
"quickApplyStyle": "readonly",
"quickSaveStyle": "readonly",
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
index cf176d6cf..a40320c63 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -106,10 +106,14 @@ body:
- StableDiffusion 1.5
- StableDiffusion 2.1
- StableDiffusion XL
- - StableDiffusion 3
- - PixArt
+ - StableDiffusion 3.x
- StableCascade
+ - FLUX.1
+ - PixArt
- Kandinsky
+ - Playground
+ - AuraFlow
+ - Any Video Model
- Other
default: 0
validations:
diff --git a/.pylintrc b/.pylintrc
index 59f1cb127..ad42ddd13 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -13,6 +13,7 @@ ignore-paths=/usr/lib/.*$,
modules/control/units,
modules/ctrlx,
modules/dml,
+ modules/freescale,
modules/ggml,
modules/hidiffusion,
modules/hijack,
diff --git a/.ruff.toml b/.ruff.toml
index c2d4a6f9a..4bab64260 100644
--- a/.ruff.toml
+++ b/.ruff.toml
@@ -7,6 +7,7 @@ exclude = [
"modules/consistory",
"modules/control/proc",
"modules/control/units",
+ "modules/freescale",
"modules/ggml",
"modules/hidiffusion",
"modules/hijack",
diff --git a/CHANGELOG.md b/CHANGELOG.md
index bc6cd163b..b778dca21 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,21 +1,230 @@
# Change Log for SD.Next
-## Update for 2024-11-22
-
-- Model loader improvements:
+## Update for 2024-12-24
+
+### Highlights for 2024-12-24
+
+### SD.Next Xmass edition: *What's new?*
+
+While we have several new supported models, workflows and tools, this release is primarily about *quality-of-life improvements*:
+- New memory management engine
+  the list of changes that went into this one is long: changes to GPU offloading, a brand-new LoRA loader, system memory management, on-the-fly quantization, an improved gguf loader, etc.
+  but the main goal is to enable modern large models to run on standard consumer GPUs
+  without the performance hits typically associated with aggressive memory swapping or the need for constant manual tweaks
+- New [documentation website](https://vladmandic.github.io/sdnext-docs/)
+ with full search and tons of new documentation
+- New settings panel with simplified and streamlined configuration
+
+We've also added support for several new models such as the highly anticipated [NVLabs Sana](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px) (see [supported models](https://vladmandic.github.io/sdnext-docs/Model-Support/) for the full list)
+And several new SOTA video models: [Lightricks LTX-Video](https://huggingface.co/Lightricks/LTX-Video), [Hunyuan Video](https://huggingface.co/tencent/HunyuanVideo) and [Genmo Mochi.1 Preview](https://huggingface.co/genmo/mochi-1-preview)
+
+And a lot of **Control** and **IPAdapter** goodies
+- for **SDXL** there is the new [ProMax](https://huggingface.co/xinsir/controlnet-union-sdxl-1.0), plus improved *Union* and *Tiling* models
+- for **FLUX.1** there are [Flux Tools](https://blackforestlabs.ai/flux-1-tools/) as well as official *Canny* and *Depth* models,
+ a cool [Redux](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) model as well as [XLabs](https://huggingface.co/XLabs-AI/flux-ip-adapter-v2) IP-adapter
+- for **SD3.5** there are official *Canny*, *Blur* and *Depth* models in addition to existing 3rd party models
+ as well as [InstantX](https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter) IP-adapter
+
+Plus a couple of new integrated workflows such as [FreeScale](https://github.com/ali-vilab/FreeScale) and [Style Aligned Image Generation](https://style-aligned-gen.github.io/)
+
+And it wouldn't be a *Xmass edition* without a couple of custom themes: *Snowflake* and *Elf-Green*!
+All-in-all, this release packs around ~180 commits worth of updates; check the changelog for the full list
+
+[ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867)
+
+## Details for 2024-12-24
+
+### New models and integrations
+
+- [NVLabs Sana](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px)
+ support for 1.6B 2048px, 1.6B 1024px and 0.6B 512px models
+  **Sana** can synthesize high-resolution images with strong text-image alignment by using **Gemma2** as the text-encoder
+  and it's *fast* - typically at least **2x** faster than sd-xl even for the 1.6B variant, and it maintains performance regardless of resolution
+ e.g., rendering at 4k is possible in less than 8GB vram
+ to use, select from *networks -> models -> reference* and models will be auto-downloaded on first use
+ *reference values*: sampler: default (or any flow-match variant), steps: 20, width/height: 1024, guidance scale: 4.5
+  *note*: like other LLM-based text-encoders, sana prefers long and descriptive prompts
+  any prompt below 300 characters will be auto-expanded using the built-in Gemma LLM before encoding, while longer prompts are passed as-is
+- **ControlNet**
+ - improved support for **Union** controlnets with granular control mode type
+ - added support for latest [Xinsir ProMax](https://huggingface.co/xinsir/controlnet-union-sdxl-1.0) all-in-one controlnet
+ - added support for multiple **Tiling** controlnets, for example [Xinsir Tile](https://huggingface.co/xinsir/controlnet-tile-sdxl-1.0)
+ *note*: when selecting tiles in control settings, you can also specify non-square ratios
+ in which case it will use context-aware image resize to maintain overall composition
+ *note*: available tiling options can be set in settings -> control
+- **IP-Adapter**
+ - FLUX.1 [XLabs](https://huggingface.co/XLabs-AI/flux-ip-adapter-v2) v1 and v2 IP-adapter
+ - FLUX.1 secondary guidance, enabled using *Attention guidance* in advanced menu
+ - SD 3.5 [InstantX](https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter) IP-adapter
+- [Flux Tools](https://blackforestlabs.ai/flux-1-tools/)
+  **Redux** is actually a tool, while **Fill** is an inpaint/outpaint-optimized version of *Flux-dev*
+ **Canny** & **Depth** are optimized versions of *Flux-dev* for their respective tasks: they are *not* ControlNets that work on top of a model
+ to use, go to image or control interface and select *Flux Tools* in scripts
+ all models are auto-downloaded on first use
+ *note*: All models are [gated](https://github.com/vladmandic/automatic/wiki/Gated) and require acceptance of terms and conditions via web page
+ *recommended*: Enable on-the-fly [quantization](https://github.com/vladmandic/automatic/wiki/Quantization) or [compression](https://github.com/vladmandic/automatic/wiki/NNCF-Compression) to reduce resource usage
+ *todo*: support for Canny/Depth LoRAs
+ - [Redux](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev): ~0.1GB
+    works together with the existing model: it analyzes the input image and uses that as guidance instead of the prompt
+    *optional*: a prompt can be used to combine its guidance with the input image
+ *recommended*: low denoise strength levels result in more variety
+ - [Fill](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev): ~23.8GB, replaces currently loaded model
+ *note*: can be used in inpaint/outpaint mode only
+ - [Canny](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev): ~23.8GB, replaces currently loaded model
+ *recommended*: guidance scale 30
+ - [Depth](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev): ~23.8GB, replaces currently loaded model
+ *recommended*: guidance scale 10
+- [Flux ControlNet LoRA](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora)
+  as an alternative to standard ControlNets, FLUX.1 also allows LoRAs to help guide the generation process
+ both **Depth** and **Canny** LoRAs are available in standard control menus
+- [StabilityAI SD35 ControlNets](https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets)
+ - In addition to previously released `InstantX` and `Alimama`, we now have *official* ones from StabilityAI
+- [Style Aligned Image Generation](https://style-aligned-gen.github.io/)
+ enable in scripts, compatible with sd-xl
+  enter multiple prompts in the prompt field, separated by new lines
+  style-aligned applies selected attention layers uniformly to all images to achieve consistency
+  can be used with or without an input image, in which case the first prompt is used to establish the baseline
+  *note*: all prompts are processed as a single batch, so vram is the limiting factor
+- [FreeScale](https://github.com/ali-vilab/FreeScale)
+ enable in scripts, compatible with sd-xl for text and img2img
+ run iterative generation of images at different scales to achieve better results
+ can render 4k sdxl images
+ *note*: disable live preview to avoid memory issues when generating large images
+
+### Video models
+
+- [Lightricks LTX-Video](https://huggingface.co/Lightricks/LTX-Video)
+ model size: 27.75gb
+ support for 0.9.0, 0.9.1 and custom safetensor-based models with full quantization and offloading support
+ support for text-to-video and image-to-video, to use, select in *scripts -> ltx-video*
+  *reference values*: steps 50, width 704, height 512, frames 161, guidance scale 3.0
+- [Hunyuan Video](https://huggingface.co/tencent/HunyuanVideo)
+ model size: 40.92gb
+ support for text-to-video, to use, select in *scripts -> hunyuan video*
+ basic support only
+  *reference values*: steps 50, width 1280, height 720, frames 129, guidance scale 6.0
+- [Genmo Mochi.1 Preview](https://huggingface.co/genmo/mochi-1-preview)
+ support for text-to-video, to use, select in *scripts -> mochi.1 video*
+ basic support only
+  *reference values*: steps 64, width 848, height 480, frames 19, guidance scale 4.5
+
+*Notes*:
+- all video models are very large and resource intensive!
+ any use on gpus below 16gb and systems below 48gb ram is experimental at best
+- sdnext support for video models is relatively basic with further optimizations pending community interest
+  any future optimizations would likely have to go into partial loading and execution instead of offloading inactive parts of the model
+- new video models use generic llms for prompting and as a result require very long and descriptive prompts
+- you may need to enable sequential offload for maximum gpu memory savings
+- optionally enable pre-quantization using bnb for additional memory savings
+- reduce number of frames and/or resolution to reduce memory usage
+
+### UI and workflow improvements
+
+- **Docs**:
+ - New documentation site!
+ - Additional Wiki content: Styles, Wildcards, etc.
+- **LoRA** handler rewrite:
+ - LoRA weights are no longer calculated on-the-fly during model execution, but are pre-calculated at the start
+    this adds perceived overhead at generate startup, but results in overall faster execution as LoRA does not need to be processed on each step
+ thanks @AI-Casanova
+  - LoRA weights can be applied/unapplied on each generate, or weight backups can be stored for later use
+ this setting has large performance and resource implications, see [Offload](https://github.com/vladmandic/automatic/wiki/Offload) wiki for details
+ - LoRA name in prompt can now also be an absolute path to a LoRA file, even if LoRA is not indexed
+ example: ``
+  - LoRA name in prompt can now also be a path to a LoRA file on `huggingface`
+ example: ``
+- **Model loader** improvements:
- detect model components on model load fail
+ - allow passing absolute path to model loader
- Flux, SD35: force unload model
- Flux: apply `bnb` quant when loading *unet/transformer*
- Flux: all-in-one safetensors
example:
- Flux: do not recast quants
-- Sampler improvements
- - update DPM FlowMatch samplers
-- Fixes:
- - update `diffusers`
- - fix README links
- - fix sdxl controlnet single-file loader
- - relax settings validator
+- **Memory** improvements:
+ - faster and more compatible *balanced offload* mode
+ - balanced offload: units are now in percentage instead of bytes
+  - balanced offload: add both high and low watermark, defaults as below (see the sketch at the end of this section)
+    `0.25` for low-watermark: skip offload if memory usage is below 25%
+    `0.70` for high-watermark: must offload if memory usage is above 70%
+ - balanced offload will attempt to run offload as non-blocking and force gc at the end
+ - change-in-behavior:
+    low-end systems, triggered by either `lowvram` or by detection of <=4GB vram, will use *sequential offload*
+ all other systems use *balanced offload* by default (can be changed in settings)
+ previous behavior was to use *model offload* on systems with <=8GB and `medvram` and no offload by default
+  - VAE upcast is now disabled by default on all systems
+ if you have issues with image decode, you'll need to enable it manually
+- **UI**:
+ - improved stats on generate completion
+ - improved live preview display and performance
+ - improved accordion behavior
+ - auto-size networks height for sidebar
+ - control: hide preview column by default
+  - control: option to hide input column
+ - control: add stats
+ - settings: reorganized and simplified
+ - browser -> server logging framework
+  - add additional themes: `black-reimagined`, thanks @Artheriax
+- **Batch**
+  - image batch processing will use caption files if they exist instead of the default prompt
+
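+A minimal sketch of the balanced-offload watermark logic described above (illustration only, not the actual offload implementation; names are hypothetical):
+
+```python
+def offload_decision(used_vram: float, total_vram: float, low: float = 0.25, high: float = 0.70) -> str:
+    """Hypothetical sketch of the balanced-offload low/high watermarks"""
+    usage = used_vram / total_vram
+    if usage < low:
+        return 'skip'   # below the low watermark: no offload needed
+    if usage > high:
+        return 'force'  # above the high watermark: offload must run
+    return 'normal'     # in between: regular balanced-offload behavior
+```
+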
+### Updates
+
+- **Quantization**
+ - Add `TorchAO` *pre* (during load) and *post* (during execution) quantization
+ **torchao** supports 4 different int-based and 3 float-based quantization schemes
+ This is in addition to existing support for:
+ - `BitsAndBytes` with 3 float-based quantization schemes
+    - `Optimum.Quanto` with 3 int-based and 2 float-based quantization schemes
+ - `GGUF` with pre-quantized weights
+  - Switch `GGUF` loader from custom to diffusers native
+- **IPEX**: update to IPEX 2.5.10+xpu
+- **OpenVINO**:
+ - update to 2024.6.0
+ - disable model caching by default
+- **Sampler** improvements
+ - UniPC, DEIS, SA, DPM-Multistep: allow FlowMatch sigma method and prediction type
+ - Euler FlowMatch: add sigma methods (*karras/exponential/betas*)
+ - Euler FlowMatch: allow using timestep presets to set sigmas
+ - DPM FlowMatch: update all and add sigma methods
+ - BDIA-DDIM: *experimental* new scheduler
+ - UFOGen: *experimental* new scheduler
+
+### Fixes
+
+- add `SD_NO_CACHE=true` env variable to disable file/folder caching
+- add settings -> networks -> embeddings -> enable/disable
+- update `diffusers`
+- fix README links
+- fix sdxl controlnet single-file loader
+- relax settings validator
+- improve js progress calls resiliency
+- fix text-to-video pipeline
+- avoid live-preview if vae-decode is running
+- allow xyz-grid with multi-axis s&r
+- fix xyz-grid with lora
+- fix api script callbacks
+- fix gpu memory monitoring
+- simplify img2img/inpaint/sketch canvas handling
+- fix prompt caching
+- fix xyz grid skip final pass
+- fix sd upscale script
+- fix cogvideox-i2v
+- lora auto-apply tags remove duplicates
+- control load model on-demand if not already loaded
+- taesd limit render to 2024px
+- taesd downscale preview to 1024px max: configurable in settings -> live preview
+- uninstall conflicting `wandb` package
+- don't skip diffusers version check if quick is specified
+- notify on torch install
+- detect pipeline from diffusers folder-style model
+- do not recast flux quants
+- fix xyz-grid with lora none
+- fix svd image2video
+- fix gallery display during generate
+- fix wildcards replacement to be unique
+- fix animatediff-xl
+- fix pag with batch count
## Update for 2024-11-21
@@ -270,7 +479,7 @@ A month later and with nearly 300 commits, here is the latest [SD.Next](https://
#### New models for 2024-10-23
-- New fine-tuned [CLiP-ViT-L]((https://huggingface.co/zer0int/CLIP-GmP-ViT-L-14)) 1st stage **text-encoders** used by most models (SD15/SDXL/SD3/Flux/etc.) brings additional details to your images
+- New fine-tuned [CLiP-ViT-L](https://huggingface.co/zer0int/CLIP-GmP-ViT-L-14) 1st stage **text-encoders** used by most models (SD15/SDXL/SD3/Flux/etc.) brings additional details to your images
- New models:
[Stable Diffusion 3.5 Large](https://huggingface.co/stabilityai/stable-diffusion-3.5-large)
[OmniGen](https://arxiv.org/pdf/2409.11340)
@@ -370,7 +579,7 @@ And there are also other goodies like multiple *XYZ grid* improvements, addition
- xyz grid support for sampler options
- metadata updates for sampler options
- modernui updates for sampler options
- - *note* sampler options defaults are not save in ui settings, they are saved in server settings
+ - *note* sampler options defaults are not saved in ui settings, they are saved in server settings
to apply your defaults, set ui values and apply via *system -> settings -> apply settings*
*sampler options*:
@@ -602,7 +811,7 @@ Examples:
- vae is list of manually downloaded safetensors
- text-encoder is list of predefined and manually downloaded text-encoders
- **controlnet** support:
- support for **InstantX/Shakker-Labs** models including [Union-Pro](InstantX/FLUX.1-dev-Controlnet-Union)
+ support for **InstantX/Shakker-Labs** models including [Union-Pro](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union)
note that flux controlnet models are large, up to 6.6GB on top of already large base model!
as such, you may need to use offloading:sequential which is not as fast, but uses far less memory
when using union model, you must also select control mode in the control unit
@@ -2117,7 +2326,7 @@ Also new is support for **SDXL-Turbo** as well as new **Kandinsky 3** models and
- in *Advanced* params
- allows control of *latent clamping*, *color centering* and *range maximization*
- supported by *XYZ grid*
- - [SD21 Turbo](https://huggingface.co/stabilityai/sd-turbo) and [SDXL Turbo]() support
+ - [SD21 Turbo](https://huggingface.co/stabilityai/sd-turbo) and [SDXL Turbo](https://huggingface.co/stabilityai/sdxl-turbo) support
- just set CFG scale (0.0-1.0) and steps (1-3) to a very low value
- compatible with original StabilityAI SDXL-Turbo or any of the newer merges
- download safetensors or select from networks -> reference
diff --git a/README.md b/README.md
index 1644281ad..2a95373d0 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
-
+
**Image Diffusion implementation with advanced features**
@@ -8,15 +8,16 @@
[![Discord](https://img.shields.io/discord/1101998836328697867?logo=Discord&svg=true)](https://discord.gg/VjvR2tabEX)
[![Sponsors](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/vladmandic)
-[Wiki](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.gg/VjvR2tabEX) | [Changelog](CHANGELOG.md)
+[Docs](https://vladmandic.github.io/sdnext-docs/) | [Wiki](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.gg/VjvR2tabEX) | [Changelog](CHANGELOG.md)
## Table of contents
+- [Documentation](https://vladmandic.github.io/sdnext-docs/)
- [SD.Next Features](#sdnext-features)
-- [Model support](#model-support)
+- [Model support](#model-support) and [Specifications]()
- [Platform support](#platform-support)
- [Getting started](#getting-started)
@@ -25,7 +26,7 @@
All individual features are not listed here, instead check [ChangeLog](CHANGELOG.md) for full list of changes
- Multiple UIs!
▹ **Standard | Modern**
-- Multiple diffusion models!
+- Multiple [diffusion models](https://vladmandic.github.io/sdnext-docs/Model-Support/)!
- Built-in Control for Text, Image, Batch and video processing!
- Multiplatform!
▹ **Windows | Linux | MacOS | nVidia | AMD | IntelArc/IPEX | DirectML | OpenVINO | ONNX+Olive | ZLUDA**
@@ -34,9 +35,7 @@ All individual features are not listed here, instead check [ChangeLog](CHANGELOG
- Platform specific autodetection and tuning performed on install
- Optimized processing with latest `torch` developments with built-in support for `torch.compile`
and multiple compile backends: *Triton, ZLUDA, StableFast, DeepCache, OpenVINO, NNCF, IPEX, OneDiff*
-- Improved prompt parser
- Built-in queue management
-- Enterprise level logging and hardened API
- Built in installer with automatic updates and dependency management
- Mobile compatible
@@ -49,42 +48,13 @@ All individual features are not listed here, instead check [ChangeLog](CHANGELOG
![screenshot-modernui](https://github.com/user-attachments/assets/39e3bc9a-a9f7-4cda-ba33-7da8def08032)
-For screenshots and informations on other available themes, see [Themes Wiki](https://github.com/vladmandic/automatic/wiki/Themes)
+For screenshots and information on other available themes, see [Themes](https://vladmandic.github.io/sdnext-docs/Themes/)
## Model support
-Additional models will be added as they become available and there is public interest in them
-See [models overview](https://github.com/vladmandic/automatic/wiki/Models) for details on each model, including their architecture, complexity and other info
-
-- [RunwayML Stable Diffusion](https://github.com/Stability-AI/stablediffusion/) 1.x and 2.x *(all variants)*
-- [StabilityAI Stable Diffusion XL](https://github.com/Stability-AI/generative-models), [StabilityAI Stable Diffusion 3.0](https://stability.ai/news/stable-diffusion-3-medium) Medium, [StabilityAI Stable Diffusion 3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-large) Medium, Large, Large Turbo
-- [StabilityAI Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid) Base, XT 1.0, XT 1.1
-- [StabilityAI Stable Cascade](https://github.com/Stability-AI/StableCascade) *Full* and *Lite*
-- [Black Forest Labs FLUX.1](https://blackforestlabs.ai/announcing-black-forest-labs/) Dev, Schnell
-- [AuraFlow](https://huggingface.co/fal/AuraFlow)
-- [AlphaVLLM Lumina-Next-SFT](https://huggingface.co/Alpha-VLLM/Lumina-Next-SFT-diffusers)
-- [Playground AI](https://huggingface.co/playgroundai/playground-v2-256px-base) *v1, v2 256, v2 512, v2 1024 and latest v2.5*
-- [Tencent HunyuanDiT](https://github.com/Tencent/HunyuanDiT)
-- [OmniGen](https://arxiv.org/pdf/2409.11340)
-- [Meissonic](https://github.com/viiika/Meissonic)
-- [Kwai Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
-- [CogView 3+](https://huggingface.co/THUDM/CogView3-Plus-3B)
-- [LCM: Latent Consistency Models](https://github.com/openai/consistency_models)
-- [aMUSEd](https://huggingface.co/amused/amused-256) 256 and 512
-- [Segmind Vega](https://huggingface.co/segmind/Segmind-Vega), [Segmind SSD-1B](https://huggingface.co/segmind/SSD-1B), [Segmind SegMoE](https://github.com/segmind/segmoe) *SD and SD-XL*, [Segmind SD Distilled](https://huggingface.co/blog/sd_distillation) *(all variants)*
-- [Kandinsky](https://github.com/ai-forever/Kandinsky-2) *2.1 and 2.2 and latest 3.0*
-- [PixArt-α XL 2](https://github.com/PixArt-alpha/PixArt-alpha) *Medium and Large*, [PixArt-Σ](https://github.com/PixArt-alpha/PixArt-sigma)
-- [Warp Wuerstchen](https://huggingface.co/blog/wuertschen)
-- [Tsinghua UniDiffusion](https://github.com/thu-ml/unidiffuser)
-- [DeepFloyd IF](https://github.com/deep-floyd/IF) *Medium and Large*
-- [ModelScope T2V](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b)
-- [BLIP-Diffusion](https://dxli94.github.io/BLIP-Diffusion-website/)
-- [KOALA 700M](https://github.com/youngwanLEE/sdxl-koala)
-- [VGen](https://huggingface.co/ali-vilab/i2vgen-xl)
-- [SDXS](https://github.com/IDKiro/sdxs)
-- [Hyper-SD](https://huggingface.co/ByteDance/Hyper-SD)
+SD.Next supports a broad range of models: [supported models](https://vladmandic.github.io/sdnext-docs/Model-Support/) and [model specs](https://vladmandic.github.io/sdnext-docs/Models/)
## Platform support
@@ -97,47 +67,29 @@ See [models overview](https://github.com/vladmandic/automatic/wiki/Models) for d
- Any GPU or device compatible with **OpenVINO** libraries on both *Windows and Linux*
- *Apple M1/M2* on *OSX* using built-in support in Torch with **MPS** optimizations
- *ONNX/Olive*
-- *AMD* GPUs on Windows using **ZLUDA** libraries
+- *AMD* GPUs on Windows using **ZLUDA** libraries
## Getting started
-- Get started with **SD.Next** by following the [installation instructions](https://github.com/vladmandic/automatic/wiki/Installation)
-- For more details, check out [advanced installation](https://github.com/vladmandic/automatic/wiki/Advanced-Install) guide
-- List and explanation of [command line arguments](https://github.com/vladmandic/automatic/wiki/CLI-Arguments)
+- Get started with **SD.Next** by following the [installation instructions](https://vladmandic.github.io/sdnext-docs/Installation/)
+- For more details, check out [advanced installation](https://vladmandic.github.io/sdnext-docs/Advanced-Install/) guide
+- List and explanation of [command line arguments](https://vladmandic.github.io/sdnext-docs/CLI-Arguments/)
- Install walkthrough [video](https://www.youtube.com/watch?v=nWTnTyFTuAs)
> [!TIP]
> And for platform specific information, check out
-> [WSL](https://github.com/vladmandic/automatic/wiki/WSL) | [Intel Arc](https://github.com/vladmandic/automatic/wiki/Intel-ARC) | [DirectML](https://github.com/vladmandic/automatic/wiki/DirectML) | [OpenVINO](https://github.com/vladmandic/automatic/wiki/OpenVINO) | [ONNX & Olive](https://github.com/vladmandic/automatic/wiki/ONNX-Runtime) | [ZLUDA](https://github.com/vladmandic/automatic/wiki/ZLUDA) | [AMD ROCm](https://github.com/vladmandic/automatic/wiki/AMD-ROCm) | [MacOS](https://github.com/vladmandic/automatic/wiki/MacOS-Python.md) | [nVidia](https://github.com/vladmandic/automatic/wiki/nVidia)
+> [WSL](https://vladmandic.github.io/sdnext-docs/WSL/) | [Intel Arc](https://vladmandic.github.io/sdnext-docs/Intel-ARC/) | [DirectML](https://vladmandic.github.io/sdnext-docs/DirectML/) | [OpenVINO](https://vladmandic.github.io/sdnext-docs/OpenVINO/) | [ONNX & Olive](https://vladmandic.github.io/sdnext-docs/ONNX-Runtime/) | [ZLUDA](https://vladmandic.github.io/sdnext-docs/ZLUDA/) | [AMD ROCm](https://vladmandic.github.io/sdnext-docs/AMD-ROCm/) | [MacOS](https://vladmandic.github.io/sdnext-docs/MacOS-Python/) | [nVidia](https://vladmandic.github.io/sdnext-docs/nVidia/) | [Docker](https://vladmandic.github.io/sdnext-docs/Docker/)
> [!WARNING]
-> If you run into issues, check out [troubleshooting](https://github.com/vladmandic/automatic/wiki/Troubleshooting) and [debugging](https://github.com/vladmandic/automatic/wiki/Debug) guides
+> If you run into issues, check out [troubleshooting](https://vladmandic.github.io/sdnext-docs/Troubleshooting/) and [debugging](https://vladmandic.github.io/sdnext-docs/Debug/) guides
> [!TIP]
-> All command line options can also be set via env variable
+> All command line options can also be set via env variable
> For example `--debug` is same as `set SD_DEBUG=true`
-## Backend support
-
-**SD.Next** supports two main backends: *Diffusers* and *Original*:
-
-- **Diffusers**: Based on new [Huggingface Diffusers](https://huggingface.co/docs/diffusers/index) implementation
- Supports *all* models listed below
- This backend is set as default for new installations
-- **Original**: Based on [LDM](https://github.com/Stability-AI/stablediffusion) reference implementation and significantly expanded on by [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
- This backend and is fully compatible with most existing functionality and extensions written for *A1111 SDWebUI*
- Supports **SD 1.x** and **SD 2.x** models
- All other model types such as *SD-XL, LCM, Stable Cascade, PixArt, Playground, Segmind, Kandinsky, etc.* require backend **Diffusers**
-
-### Collab
-
-- We'd love to have additional maintainers (with comes with full repo rights). If you're interested, ping us!
-- In addition to general cross-platform code, desire is to have a lead for each of the main platforms
-This should be fully cross-platform, but we'd really love to have additional contributors and/or maintainers to join and help lead the efforts on different platforms
-
### Credits
-- Main credit goes to [Automatic1111 WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) for original codebase
+- Main credit goes to [Automatic1111 WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) for the original codebase
- Additional credits are listed in [Credits](https://github.com/AUTOMATIC1111/stable-diffusion-webui/#credits)
- Licenses for modules are listed in [Licenses](html/licenses.html)
@@ -154,8 +106,8 @@ This should be fully cross-platform, but we'd really love to have additional con
### Docs
-If you're unsure how to use a feature, best place to start is [Wiki](https://github.com/vladmandic/automatic/wiki) and if its not there,
-check [ChangeLog](CHANGELOG.md) for when feature was first introduced as it will always have a short note on how to use it
+If you're unsure how to use a feature, the best place to start is [Docs](https://vladmandic.github.io/sdnext-docs/) and if it's not there,
+check [ChangeLog](https://vladmandic.github.io/sdnext-docs/CHANGELOG/) for when the feature was first introduced, as it will always have a short note on how to use it
### Sponsors
diff --git a/TODO.md b/TODO.md
index 973e062dc..6d89c838f 100644
--- a/TODO.md
+++ b/TODO.md
@@ -2,21 +2,30 @@
Main ToDo list can be found at [GitHub projects](https://github.com/users/vladmandic/projects)
-## Future Candidates
+## Pending
-- SD35 IPAdapter:
-- SD35 LoRA:
-- Flux IPAdapter:
-- Flux Fill/ControlNet/Redux:
-- Flux NF4:
-- SANA:
+- LoRA direct with caching
+- Previewer issues
+- Redesign postprocessing
-## Other
+## Future Candidates
+- Flux NF4 loader:
- IPAdapter negative:
- Control API enhance scripts compatibility
+- PixelSmith:
-## Workaround in place
+## Code TODO
-- GGUF
-- FlowMatch
+- TODO install: python 3.12.4 or higher cause a mess with pydantic
+- TODO install: enable ROCm for windows when available
+- TODO resize image: enable full VAE mode for resize-latent
+- TODO processing: remove duplicate mask params
+- TODO flux: fix loader for civitai nf4 models
+- TODO model loader: implement model in-memory caching
+- TODO hypertile: vae breaks when using non-standard sizes
+- TODO model load: force-reloading entire model as loading transformers only leads to massive memory usage
+- TODO lora load: direct with bnb
+- TODO lora make: support quantized flux
+- TODO control: support scripts via api
+- TODO modernui: monkey-patch for missing tabs.select event
diff --git a/cli/api-model.js b/cli/api-model.js
new file mode 100755
index 000000000..e2ce5344a
--- /dev/null
+++ b/cli/api-model.js
@@ -0,0 +1,30 @@
+#!/usr/bin/env node
+
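+// helper script: switches the SD.Next server between the checkpoints listed below via the /sdapi/v1/options endpoint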
+const sd_url = process.env.SDAPI_URL || 'http://127.0.0.1:7860';
+const sd_username = process.env.SDAPI_USR;
+const sd_password = process.env.SDAPI_PWD;
+const models = [
+ '/mnt/models/stable-diffusion/sd15/lyriel_v16.safetensors',
+ '/mnt/models/stable-diffusion/flux/flux-finesse_v2-f1h-fp8.safetensors',
+ '/mnt/models/stable-diffusion/sdxl/TempestV0.1-Artistic.safetensors',
+];
+
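+// POST a settings payload to /sdapi/v1/options, using basic auth when credentials are provided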
+async function options(data) {
+ const method = 'POST';
+ const headers = new Headers();
+ const body = JSON.stringify(data);
+ headers.set('Content-Type', 'application/json');
+  if (sd_username && sd_password) headers.set('Authorization', `Basic ${btoa(`${sd_username}:${sd_password}`)}`);
+ const res = await fetch(`${sd_url}/sdapi/v1/options`, { method, headers, body });
+ return res;
+}
+
+async function main() {
+ for (const model of models) {
+ console.log('model:', model);
+ const res = await options({ sd_model_checkpoint: model });
+ console.log('result:', res);
+ }
+}
+
+main();
diff --git a/cli/api-pulid.js b/cli/api-pulid.js
index fde0ae43b..033824e9b 100755
--- a/cli/api-pulid.js
+++ b/cli/api-pulid.js
@@ -10,12 +10,13 @@ const argparse = require('argparse');
const sd_url = process.env.SDAPI_URL || 'http://127.0.0.1:7860';
const sd_username = process.env.SDAPI_USR;
const sd_password = process.env.SDAPI_PWD;
+let args = {};
function b64(file) {
const data = fs.readFileSync(file);
- const b64 = Buffer.from(data).toString('base64');
+ const b64str = Buffer.from(data).toString('base64');
const ext = path.extname(file).replace('.', '');
- str = `data:image/${ext};base64,${b64}`;
+ const str = `data:image/${ext};base64,${b64str}`;
// console.log('b64:', ext, b64.length);
return str;
}
@@ -39,7 +40,16 @@ function options() {
if (args.pulid) {
const b64image = b64(args.pulid);
opt.script_name = 'pulid';
- opt.script_args = [b64image, 0.9];
+ opt.script_args = [
+ b64image, // b64 encoded image, required param
+ 0.9, // strength, optional
+ 20, // zero, optional
+ 'dpmpp_sde', // sampler, optional
+ 'v2', // ortho, optional
+ true, // restore (disable pulid after run), optional
+ true, // offload, optional
+ 'v1.1', // version, optional
+ ];
}
// console.log('options:', opt);
return opt;
@@ -53,8 +63,8 @@ function init() {
parser.add_argument('--height', { type: 'int', help: 'height' });
parser.add_argument('--pulid', { type: 'str', help: 'pulid init image' });
parser.add_argument('--output', { type: 'str', help: 'output path' });
- const args = parser.parse_args();
- return args
+ const parsed = parser.parse_args();
+ return parsed;
}
async function main() {
@@ -73,12 +83,12 @@ async function main() {
console.log('result:', json.info);
for (const i in json.images) { // eslint-disable-line guard-for-in
const file = args.output || `/tmp/test-${i}.jpg`;
- const data = atob(json.images[i])
+ const data = atob(json.images[i]);
fs.writeFileSync(file, data, 'binary');
console.log('image saved:', file);
}
}
}
-const args = init();
+args = init();
main();
diff --git a/cli/full-test.sh b/cli/full-test.sh
index e410528ad..912dc3a5b 100755
--- a/cli/full-test.sh
+++ b/cli/full-test.sh
@@ -1,5 +1,8 @@
#!/usr/bin/env bash
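+# run the node-based API smoke tests before activating the venv for the python ones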
+node cli/api-txt2img.js
+node cli/api-pulid.js
+
source venv/bin/activate
echo image-exif
python cli/api-info.py --input html/logo-bg-0.jpg
diff --git a/cli/load-unet.py b/cli/load-unet.py
index 2398cdb64..c910101b0 100644
--- a/cli/load-unet.py
+++ b/cli/load-unet.py
@@ -33,13 +33,13 @@ def set_module_tensor(
stats.dtypes[value.dtype] = 0
stats.dtypes[value.dtype] += 1
if name in module._buffers: # pylint: disable=protected-access
- module._buffers[name] = value.to(device=device, dtype=dtype, non_blocking=True) # pylint: disable=protected-access
+ module._buffers[name] = value.to(device=device, dtype=dtype) # pylint: disable=protected-access
if 'buffers' not in stats.weights:
stats.weights['buffers'] = 0
stats.weights['buffers'] += 1
elif value is not None:
param_cls = type(module._parameters[name]) # pylint: disable=protected-access
- module._parameters[name] = param_cls(value, requires_grad=old_value.requires_grad).to(device, dtype=dtype, non_blocking=True) # pylint: disable=protected-access
+ module._parameters[name] = param_cls(value, requires_grad=old_value.requires_grad).to(device, dtype=dtype) # pylint: disable=protected-access
if 'parameters' not in stats.weights:
stats.weights['parameters'] = 0
stats.weights['parameters'] += 1
diff --git a/configs/flux/vae/config.json b/configs/flux/vae/config.json
index b43183d0f..7ecb342c2 100644
--- a/configs/flux/vae/config.json
+++ b/configs/flux/vae/config.json
@@ -14,7 +14,7 @@
"DownEncoderBlock2D",
"DownEncoderBlock2D"
],
- "force_upcast": true,
+ "force_upcast": false,
"in_channels": 3,
"latent_channels": 16,
"latents_mean": null,
diff --git a/configs/sd15/vae/config.json b/configs/sd15/vae/config.json
index 55d78924f..2cba0e824 100644
--- a/configs/sd15/vae/config.json
+++ b/configs/sd15/vae/config.json
@@ -14,6 +14,7 @@
"DownEncoderBlock2D",
"DownEncoderBlock2D"
],
+ "force_upcast": false,
"in_channels": 3,
"latent_channels": 4,
"layers_per_block": 2,
diff --git a/configs/sd3/vae/config.json b/configs/sd3/vae/config.json
index 58e7764fb..f6f4e8684 100644
--- a/configs/sd3/vae/config.json
+++ b/configs/sd3/vae/config.json
@@ -15,7 +15,7 @@
"DownEncoderBlock2D",
"DownEncoderBlock2D"
],
- "force_upcast": true,
+ "force_upcast": false,
"in_channels": 3,
"latent_channels": 16,
"latents_mean": null,
diff --git a/configs/sdxl/vae/config.json b/configs/sdxl/vae/config.json
index a66a171ba..1c7a60866 100644
--- a/configs/sdxl/vae/config.json
+++ b/configs/sdxl/vae/config.json
@@ -15,7 +15,7 @@
"DownEncoderBlock2D",
"DownEncoderBlock2D"
],
- "force_upcast": true,
+ "force_upcast": false,
"in_channels": 3,
"latent_channels": 4,
"layers_per_block": 2,
diff --git a/extensions-builtin/Lora/lora_extract.py b/extensions-builtin/Lora/lora_extract.py
index c2e0a275b..1d92f3c6e 100644
--- a/extensions-builtin/Lora/lora_extract.py
+++ b/extensions-builtin/Lora/lora_extract.py
@@ -182,19 +182,6 @@ def make_lora(fn, maxrank, auto_rank, rank_ratio, modules, overwrite):
progress.remove_task(task)
t3 = time.time()
- # TODO: Handle quant for Flux
- # if 'te' in modules and getattr(shared.sd_model, 'transformer', None) is not None:
- # for name, module in shared.sd_model.transformer.named_modules():
- # if "norm" in name and "linear" not in name:
- # continue
- # weights_backup = getattr(module, "network_weights_backup", None)
- # if weights_backup is None:
- # continue
- # module.svdhandler = SVDHandler()
- # module.svdhandler.network_name = "lora_transformer_" + name.replace(".", "_")
- # module.svdhandler.decompose(module.weight, weights_backup)
- # module.svdhandler.findrank(rank, rank_ratio)
-
lora_state_dict = {}
for sub in ['text_encoder', 'text_encoder_2', 'unet', 'transformer']:
submodel = getattr(shared.sd_model, sub, None)
diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py
index 5e6eaef6c..a410a8e3b 100644
--- a/extensions-builtin/Lora/network_lora.py
+++ b/extensions-builtin/Lora/network_lora.py
@@ -22,7 +22,6 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
self.dim = weights.w["lora_down.weight"].shape[0]
def create_module(self, weights, key, none_ok=False):
- from modules.shared import opts
weight = weights.get(key)
if weight is None and none_ok:
return None
@@ -32,7 +31,7 @@ def create_module(self, weights, key, none_ok=False):
if is_linear:
weight = weight.reshape(weight.shape[0], -1)
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
- elif is_conv and key == "lora_down.weight" or key == "dyn_up":
+ elif is_conv and (key == "lora_down.weight" or key == "dyn_up"):
if len(weight.shape) == 2:
weight = weight.reshape(weight.shape[0], -1, 1, 1)
if weight.shape[2] != 1 or weight.shape[3] != 1:
@@ -41,7 +40,7 @@ def create_module(self, weights, key, none_ok=False):
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
elif is_conv and key == "lora_mid.weight":
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
- elif is_conv and key == "lora_up.weight" or key == "dyn_down":
+ elif is_conv and (key == "lora_up.weight" or key == "dyn_down"):
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
else:
raise AssertionError(f'Lora unsupported: layer={self.network_key} type={type(self.sd_module).__name__}')
@@ -49,8 +48,6 @@ def create_module(self, weights, key, none_ok=False):
if weight.shape != module.weight.shape:
weight = weight.reshape(module.weight.shape)
module.weight.copy_(weight)
- if opts.lora_load_gpu:
- module = module.to(device=devices.device, dtype=devices.dtype)
module.weight.requires_grad_(False)
return module
diff --git a/extensions-builtin/Lora/network_overrides.py b/extensions-builtin/Lora/network_overrides.py
index 5334f3c1b..b5c28b718 100644
--- a/extensions-builtin/Lora/network_overrides.py
+++ b/extensions-builtin/Lora/network_overrides.py
@@ -26,7 +26,6 @@
force_models = [ # forced always
'sc',
- # 'sd3',
'kandinsky',
'hunyuandit',
'auraflow',
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index db617ee5b..fd6287c62 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -88,7 +88,7 @@ def assign_network_names_to_compvis_modules(sd_model):
network_name = name.replace(".", "_")
network_layer_mapping[network_name] = module
module.network_layer_name = network_name
- shared.sd_model.network_layer_mapping = network_layer_mapping
+ sd_model.network_layer_mapping = network_layer_mapping
def load_diffusers(name, network_on_disk, lora_scale=shared.opts.extra_networks_default_multiplier) -> network.Network:
@@ -141,7 +141,7 @@ def load_network(name, network_on_disk) -> network.Network:
sd = sd_models.read_state_dict(network_on_disk.filename, what='network')
if shared.sd_model_type == 'f1': # if kohya flux lora, convert state_dict
sd = lora_convert._convert_kohya_flux_lora_to_diffusers(sd) or sd # pylint: disable=protected-access
- assign_network_names_to_compvis_modules(shared.sd_model) # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
+ assign_network_names_to_compvis_modules(shared.sd_model)
keys_failed_to_match = {}
matched_networks = {}
bundle_embeddings = {}
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index ffbef47d9..24723dd7f 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -5,7 +5,7 @@
from network import NetworkOnDisk
from ui_extra_networks_lora import ExtraNetworksPageLora
from extra_networks_lora import ExtraNetworkLora
-from modules import script_callbacks, extra_networks, ui_extra_networks, ui_models # pylint: disable=unused-import
+from modules import script_callbacks, extra_networks, ui_extra_networks, ui_models, shared # pylint: disable=unused-import
 re_lora = re.compile("<lora:([^:]+):")
-    if int(sys.version_info.major) == 3 and int(sys.version_info.minor) == 12 and int(sys.version_info.micro) > 3: # TODO python 3.12.4 or higher cause a mess with pydantic
+ if int(sys.version_info.major) == 3 and int(sys.version_info.minor) == 12 and int(sys.version_info.micro) > 3: # TODO install: python 3.12.4 or higher cause a mess with pydantic
log.error(f"Python version incompatible: {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro} required 3.12.3 or lower")
if reason is not None:
log.error(reason)
@@ -457,9 +457,9 @@ def check_python(supported_minors=[9, 10, 11, 12], reason=None):
# check diffusers version
def check_diffusers():
- if args.skip_all or args.skip_requirements:
+ if args.skip_all or args.skip_git:
return
- sha = 'b5fd6f13f5434d69d919cc8cedf0b11db664cf06'
+ sha = '6dfaec348780c6153a4cfd03a01972a291d67f82' # diffusers commit hash
pkg = pkg_resources.working_set.by_key.get('diffusers', None)
minor = int(pkg.version.split('.')[1] if pkg is not None else 0)
cur = opts.get('diffusers_version', '') if minor > 0 else ''
@@ -483,6 +483,7 @@ def check_onnx():
def check_torchao():
+ """
if args.skip_all or args.skip_requirements:
return
if installed('torchao', quiet=True):
@@ -492,6 +493,8 @@ def check_torchao():
pip('uninstall --yes torchao', ignore=True, quiet=True, uv=False)
for m in [m for m in sys.modules if m.startswith('torchao')]:
del sys.modules[m]
+ """
+ return
def install_cuda():
@@ -549,7 +552,7 @@ def install_rocm_zluda():
log.info(msg)
torch_command = ''
if sys.platform == "win32":
- # TODO after ROCm for Windows is released
+ # TODO install: enable ROCm for windows when available
if args.device_id is not None:
if os.environ.get('HIP_VISIBLE_DEVICES', None) is not None:
@@ -634,25 +637,30 @@ def install_ipex(torch_command):
os.environ.setdefault('NEOReadDebugKeys', '1')
if os.environ.get("ClDeviceGlobalMemSizeAvailablePercent", None) is None:
os.environ.setdefault('ClDeviceGlobalMemSizeAvailablePercent', '100')
+ if os.environ.get("PYTORCH_ENABLE_XPU_FALLBACK", None) is None:
+ os.environ.setdefault('PYTORCH_ENABLE_XPU_FALLBACK', '1')
if "linux" in sys.platform:
- torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.3.1+cxx11.abi torchvision==0.18.1+cxx11.abi intel-extension-for-pytorch==2.3.110+xpu oneccl_bind_pt==2.3.100+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/')
+ torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.5.1+cxx11.abi torchvision==0.20.1+cxx11.abi intel-extension-for-pytorch==2.5.10+xpu oneccl_bind_pt==2.5.0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/')
# torch_command = os.environ.get('TORCH_COMMAND', 'torch torchvision --index-url https://download.pytorch.org/whl/test/xpu') # test wheels are stable previews, significantly slower than IPEX
# os.environ.setdefault('TENSORFLOW_PACKAGE', 'tensorflow==2.15.1 intel-extension-for-tensorflow[xpu]==2.15.0.1')
else:
torch_command = os.environ.get('TORCH_COMMAND', '--pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/xpu') # torchvision doesn't exist on test/stable branch for windows
- install(os.environ.get('OPENVINO_PACKAGE', 'openvino==2024.3.0'), 'openvino', ignore=True)
- install('nncf==2.7.0', 'nncf', ignore=True)
+ install(os.environ.get('OPENVINO_PACKAGE', 'openvino==2024.5.0'), 'openvino', ignore=True)
+ install('nncf==2.7.0', ignore=True, no_deps=True) # requires older pandas
install(os.environ.get('ONNXRUNTIME_PACKAGE', 'onnxruntime-openvino'), 'onnxruntime-openvino', ignore=True)
return torch_command
def install_openvino(torch_command):
- check_python(supported_minors=[8, 9, 10, 11, 12], reason='OpenVINO backend requires Python 3.9, 3.10 or 3.11')
+    check_python(supported_minors=[9, 10, 11, 12], reason='OpenVINO backend requires Python 3.9, 3.10, 3.11 or 3.12')
log.info('OpenVINO: selected')
- torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.3.1+cpu torchvision==0.18.1+cpu --index-url https://download.pytorch.org/whl/cpu')
- install(os.environ.get('OPENVINO_PACKAGE', 'openvino==2024.3.0'), 'openvino')
+ if sys.platform == 'darwin':
+ torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.3.1 torchvision==0.18.1')
+ else:
+ torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.3.1+cpu torchvision==0.18.1+cpu --index-url https://download.pytorch.org/whl/cpu')
+ install(os.environ.get('OPENVINO_PACKAGE', 'openvino==2024.6.0'), 'openvino')
install(os.environ.get('ONNXRUNTIME_PACKAGE', 'onnxruntime-openvino'), 'onnxruntime-openvino', ignore=True)
- install('nncf==2.12.0', 'nncf')
+ install('nncf==2.14.1', 'nncf')
os.environ.setdefault('PYTORCH_TRACING_MODE', 'TORCHFX')
if os.environ.get("NEOReadDebugKeys", None) is None:
os.environ.setdefault('NEOReadDebugKeys', '1')
@@ -682,7 +690,9 @@ def install_torch_addons():
if opts.get('nncf_compress_weights', False) and not args.use_openvino:
install('nncf==2.7.0', 'nncf')
if opts.get('optimum_quanto_weights', False):
- install('optimum-quanto', 'optimum-quanto')
+ install('optimum-quanto==0.2.6', 'optimum-quanto')
+ if not args.experimental:
+ uninstall('wandb', quiet=True)
if triton_command is not None:
install(triton_command, 'triton', quiet=True)
@@ -727,8 +737,6 @@ def check_torch():
torch_command = install_rocm_zluda()
elif is_ipex_available:
torch_command = install_ipex(torch_command)
- elif allow_openvino:
- torch_command = install_openvino(torch_command)
else:
machine = platform.machine()
@@ -746,6 +754,8 @@ def check_torch():
log.warning('Torch: CPU-only version installed')
torch_command = os.environ.get('TORCH_COMMAND', 'torch torchvision')
if 'torch' in torch_command and not args.version:
+ if not installed('torch'):
+ log.info(f'Torch: download and install in progress... cmd="{torch_command}"')
install(torch_command, 'torch torchvision', quiet=True)
else:
try:
@@ -999,8 +1009,8 @@ def install_optional():
install('basicsr')
install('gfpgan')
install('clean-fid')
- install('optimum-quanto', ignore=True)
- install('bitsandbytes', ignore=True)
+    install('optimum-quanto==0.2.6', ignore=True)
+ install('bitsandbytes==0.45.0', ignore=True)
install('pynvml', ignore=True)
install('ultralytics==8.3.40', ignore=True)
install('Cython', ignore=True)
@@ -1174,9 +1184,15 @@ def same(ver):
def check_venv():
+ def try_relpath(p):
+ try:
+ return os.path.relpath(p)
+ except ValueError:
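+            # relpath raises ValueError when the path is on a different drive (e.g. on Windows), so keep the absolute path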
+ return p
+
import site
- pkg_path = [os.path.relpath(p) for p in site.getsitepackages() if os.path.exists(p)]
- log.debug(f'Packages: venv={os.path.relpath(sys.prefix)} site={pkg_path}')
+ pkg_path = [try_relpath(p) for p in site.getsitepackages() if os.path.exists(p)]
+ log.debug(f'Packages: venv={try_relpath(sys.prefix)} site={pkg_path}')
for p in pkg_path:
invalid = []
for f in os.listdir(p):
diff --git a/javascript/base.css b/javascript/base.css
index 7daa8b2bd..a30a71845 100644
--- a/javascript/base.css
+++ b/javascript/base.css
@@ -25,7 +25,6 @@
.progressDiv .progress { width: 0%; height: 20px; background: #0060df; color: white; font-weight: bold; line-height: 20px; padding: 0 8px 0 0; text-align: right; overflow: visible; white-space: nowrap; padding: 0 0.5em; }
.livePreview { position: absolute; z-index: 50; background-color: transparent; width: -moz-available; width: -webkit-fill-available; }
.livePreview img { position: absolute; object-fit: contain; width: 100%; height: 100%; }
-.dark .livePreview { background-color: rgb(17 24 39 / var(--tw-bg-opacity)); }
.popup-metadata { color: white; background: #0000; display: inline-block; white-space: pre-wrap; font-size: 0.75em; }
/* fullpage image viewer */
@@ -80,7 +79,7 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt
/* extra networks */
.extra-networks > div { margin: 0; border-bottom: none !important; }
-.extra-networks .second-line { display: flex; width: -moz-available; width: -webkit-fill-available; gap: 0.3em; box-shadow: var(--input-shadow); }
+.extra-networks .second-line { display: flex; width: -moz-available; width: -webkit-fill-available; gap: 0.3em; box-shadow: var(--input-shadow); margin-bottom: 2px; }
.extra-networks .search { flex: 1; }
.extra-networks .description { flex: 3; }
.extra-networks .tab-nav > button { margin-right: 0; height: 24px; padding: 2px 4px 2px 4px; }
@@ -89,7 +88,7 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt
.extra-networks .custom-button { width: 120px; width: 100%; background: none; justify-content: left; text-align: left; padding: 2px 8px 2px 16px; text-indent: -8px; box-shadow: none; line-break: auto; }
.extra-networks .custom-button:hover { background: var(--button-primary-background-fill) }
.extra-networks-tab { padding: 0 !important; }
-.extra-network-subdirs { background: var(--input-background-fill); overflow-x: hidden; overflow-y: auto; min-width: max(15%, 120px); padding-top: 0.5em; margin-top: -4px !important; }
+.extra-network-subdirs { background: var(--input-background-fill); overflow-x: hidden; overflow-y: auto; min-width: max(15%, 120px); padding-top: 0.5em; }
.extra-networks-page { display: flex }
.extra-network-cards { display: flex; flex-wrap: wrap; overflow-y: auto; overflow-x: hidden; align-content: flex-start; width: -moz-available; width: -webkit-fill-available; }
.extra-network-cards .card { height: fit-content; margin: 0 0 0.5em 0.5em; position: relative; scroll-snap-align: start; scroll-margin-top: 0; }
diff --git a/javascript/black-teal-reimagined.css b/javascript/black-teal-reimagined.css
new file mode 100644
index 000000000..28176d247
--- /dev/null
+++ b/javascript/black-teal-reimagined.css
@@ -0,0 +1,1193 @@
+/* Generic HTML Tags */
+@font-face {
+ font-family: 'NotoSans';
+ font-display: swap;
+ font-style: normal;
+ font-weight: 100;
+ src: local('NotoSans'), url('notosans-nerdfont-regular.ttf');
+}
+
+html {
+ scroll-behavior: smooth;
+}
+
+:root,
+.light,
+.dark {
+ --font: 'NotoSans';
+ --font-mono: 'ui-monospace', 'Consolas', monospace;
+ --font-size: 16px;
+
+ /* Primary Colors */
+ --primary-50: #7dffff;
+ --primary-100: #72e8e8;
+ --primary-200: #67d2d2;
+ --primary-300: #5dbcbc;
+ --primary-400: #52a7a7;
+ --primary-500: #489292;
+ --primary-600: #3e7d7d;
+ --primary-700: #356969;
+ --primary-800: #2b5656;
+ --primary-900: #224444;
+ --primary-950: #193232;
+
+ /* Neutral Colors */
+ --neutral-50: #f0f0f0;
+ --neutral-100: #e0e0e0;
+ --neutral-200: #d0d0d0;
+ --neutral-300: #b0b0b0;
+ --neutral-400: #909090;
+ --neutral-500: #707070;
+ --neutral-600: #606060;
+ --neutral-700: #404040;
+ --neutral-800: #303030;
+ --neutral-900: #202020;
+ --neutral-950: #101010;
+
+ /* Highlight and Inactive Colors */
+ --highlight-color: var(--primary-200);
+ --inactive-color: var(--primary-800);
+
+ /* Text Colors */
+ --body-text-color: var(--neutral-100);
+ --body-text-color-subdued: var(--neutral-300);
+
+ /* Background Colors */
+ --background-color: var(--neutral-950);
+ --background-fill-primary: var(--neutral-800);
+ --input-background-fill: var(--neutral-900);
+
+ /* Padding and Borders */
+ --input-padding: 4px;
+ --input-shadow: none;
+ --button-primary-text-color: var(--neutral-100);
+ --button-primary-background-fill: var(--primary-600);
+ --button-primary-background-fill-hover: var(--primary-800);
+ --button-secondary-text-color: var(--neutral-100);
+ --button-secondary-background-fill: var(--neutral-900);
+ --button-secondary-background-fill-hover: var(--neutral-600);
+
+ /* Border Radius */
+ --radius-xs: 2px;
+ --radius-sm: 4px;
+ --radius-md: 6px;
+ --radius-lg: 8px;
+ --radius-xl: 10px;
+ --radius-xxl: 15px;
+ --radius-xxxl: 20px;
+
+ /* Shadows */
+ --shadow-sm: 0 1px 2px rgba(0, 0, 0, 0.1);
+ --shadow-md: 0 2px 4px rgba(0, 0, 0, 0.1);
+ --shadow-lg: 0 4px 8px rgba(0, 0, 0, 0.1);
+ --shadow-xl: 0 8px 16px rgba(0, 0, 0, 0.1);
+
+ /* Animation */
+ --transition: all 0.3s ease;
+
+ /* Scrollbar */
+ --scrollbar-bg: var(--neutral-800);
+ --scrollbar-thumb: var(--highlight-color);
+}
+
+html {
+ font-size: var(--font-size);
+ font-family: var(--font);
+}
+
+body,
+button,
+input,
+select,
+textarea {
+ font-family: var(--font);
+ color: var(--body-text-color);
+ transition: var(--transition);
+}
+
+button {
+ max-width: 400px;
+ white-space: nowrap;
+ padding: 8px 12px;
+ border: none;
+ border-radius: var(--radius-md);
+ background-color: var(--button-primary-background-fill);
+ color: var(--button-primary-text-color);
+ cursor: pointer;
+ box-shadow: var(--shadow-sm);
+ transition: transform 0.2s ease, background-color 0.3s ease;
+}
+
+button:hover {
+ background-color: var(--button-primary-background-fill-hover);
+ transform: scale(1.05);
+}
+
+/* Range Input Styles */
+.slider-container {
+ width: 100%;
+ /* Ensures the container takes full width */
+ max-width: 100%;
+ /* Prevents overflow */
+ padding: 0 10px;
+ /* Adds padding for aesthetic spacing */
+ box-sizing: border-box;
+ /* Ensures padding doesn't affect width */
+}
+
+input[type='range'] {
+ display: block;
+ margin: 0;
+ padding: 0;
+ height: 1em;
+ background-color: transparent;
+ overflow: hidden;
+ cursor: pointer;
+ box-shadow: none;
+ -webkit-appearance: none;
+ opacity: 0.7;
+ appearance: none;
+ width: 100%;
+ /* Makes the slider responsive */
+}
+
+input[type='range'] {
+ opacity: 1;
+}
+
+input[type='range']::-webkit-slider-thumb {
+ -webkit-appearance: none;
+ height: 1em;
+ width: 1em;
+ background-color: var(--highlight-color);
+ border-radius: var(--radius-xs);
+ box-shadow: var(--shadow-md);
+ cursor: pointer;
+ /* Ensures the thumb is clickable */
+}
+
+input[type='range']::-webkit-slider-runnable-track {
+ -webkit-appearance: none;
+ height: 6px;
+ background: var(--input-background-fill);
+ border-radius: var(--radius-md);
+}
+
+input[type='range']::-moz-range-thumb {
+ height: 1em;
+ width: 1em;
+ background-color: var(--highlight-color);
+ border-radius: var(--radius-xs);
+ box-shadow: var(--shadow-md);
+ cursor: pointer;
+ /* Ensures the thumb is clickable */
+}
+
+input[type='range']::-moz-range-track {
+ height: 6px;
+ background: var(--input-background-fill);
+ border-radius: var(--radius-md);
+}
+
+@media (max-width: 768px) {
+ .slider-container {
+ width: 100%;
+ /* Adjust width for smaller screens */
+ }
+
+ .networks-menu,
+ .styles-menu {
+ width: 100%;
+ /* Ensure menus are full width */
+ margin: 0;
+ /* Reset margins for smaller screens */
+ }
+}
+
+/* Scrollbar Styles */
+:root {
+ scrollbar-color: var(--scrollbar-thumb) var(--scrollbar-bg);
+}
+
+::-webkit-scrollbar {
+ width: 12px;
+ height: 12px;
+}
+
+::-webkit-scrollbar-track {
+ background: var(--scrollbar-bg);
+ border-radius: var(--radius-lg);
+}
+
+::-webkit-scrollbar-thumb {
+ background-color: var(--scrollbar-thumb);
+ border-radius: var(--radius-lg);
+ box-shadow: var(--shadow-sm);
+}
+
+/* Tab Navigation Styles */
+.tab-nav {
+ display: flex;
+ /* Use flexbox for layout */
+ justify-content: space-evenly;
+ /* Space out the tabs evenly */
+ align-items: center;
+ /* Center items vertically */
+ background: var(--background-color);
+ /* Background color */
+ border-bottom: 1px dashed var(--highlight-color) !important;
+ /* Bottom border for separation */
+ box-shadow: var(--shadow-md);
+ /* Shadow for depth */
+ margin-bottom: 5px;
+ /* Add some space between the tab nav and the content */
+ padding-bottom: 5px;
+ /* Add space between buttons and border */
+}
+
+/* Individual Tab Styles */
+.tab-nav>button {
+ background: var(--neutral-900);
+ /* Dark neutral background for default state */
+ color: var(--text-color);
+ /* Text color */
+ border: 1px solid var(--highlight-color);
+ /* Thin highlight border */
+ border-radius: var(--radius-xxl);
+ /* Rounded corners */
+ cursor: pointer;
+ /* Pointer cursor */
+ transition: background 0.3s ease, color 0.3s ease;
+ /* Smooth transition */
+ padding-top: 5px;
+ padding-bottom: 5px;
+ padding-right: 10px;
+ padding-left: 10px;
+ margin-bottom: 3px;
+}
+
+/* Active Tab Style */
+.tab-nav>button.selected {
+ background: var(--primary-100);
+ /* Highlight active tab */
+ color: var(--background-color);
+ /* Change text color for active tab */
+}
+
+/* Hover State for Tabs */
+.tab-nav>button:hover {
+ background: var(--highlight-color);
+ /* Background on hover */
+ color: var(--background-color);
+ /* Change text color on hover */
+}
+
+/* Responsive Styles */
+@media (max-width: 768px) {
+ .tab-nav {
+ flex-direction: column;
+ /* Stack tabs vertically on smaller screens */
+ align-items: stretch;
+ /* Stretch tabs to full width */
+ }
+
+ .tab-nav>button {
+ width: 100%;
+ /* Full width for buttons */
+ text-align: left;
+ /* Align text to the left */
+ }
+}
+
+/* Quick Settings Panel Styles */
+#quicksettings {
+ background: var(--background-color);
+ /* Background color */
+ box-shadow: var(--shadow-lg);
+ /* Shadow for depth */
+ border-radius: var(--radius-lg);
+ /* Rounded corners */
+ padding: 1em;
+ /* Padding for spacing */
+ z-index: 200;
+ /* Ensure it stays on top */
+}
+
+/* Quick Settings Header */
+#quicksettings .header {
+ font-size: var(--text-lg);
+ /* Font size for header */
+ font-weight: bold;
+ /* Bold text */
+ margin-bottom: 0.5em;
+ /* Space below header */
+}
+
+/* Quick Settings Options */
+#quicksettings .option {
+ display: flex;
+ /* Flexbox for layout */
+ justify-content: space-between;
+ /* Space between label and toggle */
+ align-items: center;
+ /* Center items vertically */
+ padding: 0.5em 0;
+ /* Padding for each option */
+ border-bottom: 1px solid var(--neutral-600);
+ /* Separator line */
+}
+
+/* Option Label Styles */
+#quicksettings .option label {
+ color: var(--text-color);
+ /* Text color */
+}
+
+/* Toggle Switch Styles */
+#quicksettings .option input[type="checkbox"] {
+ cursor: pointer;
+ /* Pointer cursor */
+}
+
+/* Quick Settings Footer */
+#quicksettings .footer {
+ margin-top: 1em;
+ /* Space above footer */
+ text-align: right;
+ /* Align text to the right */
+}
+
+/* Close Button Styles */
+#quicksettings .footer button {
+ background: var(--button-primary-background-fill);
+ /* Button background */
+ color: var(--button-primary-text-color);
+ /* Button text color */
+ border: none;
+ /* No border */
+ border-radius: var(--radius-md);
+ /* Rounded corners */
+ padding: 0.5em 1em;
+ /* Padding for button */
+ cursor: pointer;
+ /* Pointer cursor */
+ transition: 0.3s ease;
+ /* Smooth transition */
+}
+
+/* Close Button Hover State */
+#quicksettings .footer button:hover {
+ background: var(--highlight-color);
+ /* Change background on hover */
+}
+
+/* Responsive Styles */
+@media (max-width: 768px) {
+ #quicksettings {
+ right: 10px;
+ /* Adjust position for smaller screens */
+ width: 90%;
+ /* Full width on smaller screens */
+ }
+}
+
+/* Form Styles */
+div.form, #txt2img_seed_row, #txt2img_subseed_row {
+ border-width: 0;
+ box-shadow: var(--shadow-md);
+ background: var(--background-fill-primary);
+ border-bottom: 3px solid var(--highlight-color);
+ padding: 3px;
+ border-radius: var(--radius-lg);
+ margin: 1px;
+}
+
+/* Image preview styling */
+#txt2img_gallery {
+ background: var(--background-fill-primary);
+ padding: 5px;
+ margin: 0px;
+}
+
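+/* Subtle background pulse used by the .livePreview container */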
+@keyframes colorChange {
+ 0% {
+ background-color: var(--neutral-800);
+ }
+ 50% {
+ background-color: var(--neutral-700);
+ }
+ 100% {
+ background-color: var(--neutral-800);
+ }
+}
+
+.livePreview {
+ animation: colorChange 3s ease-in-out infinite; /* Adjust the duration as needed */
+ padding: 5px;
+}
+
+/* Gradio Style Classes */
+fieldset .gr-block.gr-box,
+label.block span {
+ padding: 0;
+ margin-top: -4px;
+}
+
+.border-2 {
+ border-width: 0;
+}
+
+.border-b-2 {
+ border-bottom-width: 2px;
+ border-color: var(--highlight-color) !important;
+ padding-bottom: 2px;
+ margin-bottom: 8px;
+}
+
+.bg-white {
+ color: lightyellow;
+ background-color: var(--inactive-color);
+}
+
+.gr-box {
+ border-radius: var(--radius-sm) !important;
+ background-color: var(--neutral-950) !important;
+ box-shadow: var(--shadow-md);
+ border-width: 0;
+ padding: 4px;
+ margin: 12px 0;
+}
+
+.gr-button {
+ font-weight: normal;
+ box-shadow: var(--shadow-sm);
+ font-size: 0.8rem;
+ min-width: 32px;
+ min-height: 32px;
+ padding: 3px;
+ margin: 3px;
+ transition: var(--transition);
+}
+
+.gr-button:hover {
+ background-color: var(--highlight-color);
+}
+
+.gr-check-radio {
+ background-color: var(--inactive-color);
+ border-width: 0;
+ border-radius: var(--radius-lg);
+ box-shadow: var(--shadow-sm);
+}
+
+.gr-check-radio:checked {
+ background-color: var(--highlight-color);
+}
+
+.gr-compact {
+ background-color: var(--background-color);
+}
+
+.gr-form {
+ border-width: 0;
+}
+
+.gr-input {
+ background-color: var(--neutral-800) !important;
+ padding: 4px;
+ margin: 4px;
+ border-radius: var(--radius-md);
+ transition: var(--transition);
+}
+
+.gr-input:hover {
+ background-color: var(--neutral-700);
+}
+
+.gr-input-label {
+ color: lightyellow;
+ border-width: 0;
+ background: transparent;
+ padding: 2px !important;
+}
+
+.gr-panel {
+ background-color: var(--background-color);
+ border-radius: var(--radius-md);
+ box-shadow: var(--shadow-md);
+}
+
+.eta-bar {
+ display: none !important;
+}
+
+.gradio-slider {
+ max-width: 200px;
+}
+
+.gradio-slider input[type="number"] {
+ background: var(--neutral-950);
+ margin-top: 2px;
+}
+
+.gradio-image {
+ height: unset !important;
+}
+
+svg.feather.feather-image,
+.feather .feather-image {
+ display: none;
+}
+
+.gap-2 {
+ padding-top: 8px;
+}
+
+.gr-box>div>div>input.gr-text-input {
+ right: 0;
+ width: 4em;
+ padding: 0;
+ top: -12px;
+ border: none;
+ max-height: 20px;
+}
+
+.output-html {
+ line-height: 1.2rem;
+ overflow-x: hidden;
+}
+
+.output-html>div {
+ margin-bottom: 8px;
+}
+
+.overflow-hidden .flex .flex-col .relative .col .gap-4 {
+ min-width: var(--left-column);
+ max-width: var(--left-column);
+}
+
+.p-2 {
+ padding: 0;
+}
+
+.px-4 {
+ padding-left: 1rem;
+ padding-right: 1rem;
+}
+
+.py-6 {
+ padding-bottom: 0;
+}
+
+.tabs {
+ background-color: var(--background-color);
+}
+
+.block.token-counter span {
+ background-color: var(--input-background-fill) !important;
+ box-shadow: 2px 2px 2px #111;
+ border: none !important;
+ font-size: 0.7rem;
+}
+
+.label-wrap {
+ margin: 8px 0px 4px 0px;
+}
+
+.gradio-button.tool {
+ border: none;
+ background: none;
+ box-shadow: none;
+ filter: hue-rotate(340deg) saturate(0.5);
+}
+
+#tab_extensions table td,
+#tab_extensions table th,
+#tab_config table td,
+#tab_config table th {
+ border: none;
+}
+
+#tab_extensions table tr:hover,
+#tab_config table tr:hover {
+ background-color: var(--neutral-500) !important;
+}
+
+#tab_extensions table,
+#tab_config table {
+ width: 96vw;
+ background-color: var(--neutral-900);
+}
+
+#tab_extensions table thead,
+#tab_config table thead {
+ background-color: var(--neutral-700);
+}
+
+/* Automatic Style Classes */
+.progressDiv {
+ border-radius: var(--radius-sm) !important;
+ position: fixed;
+ top: 44px;
+ right: 26px;
+ max-width: 262px;
+ height: 48px;
+ z-index: 99;
+ box-shadow: var(--button-shadow);
+}
+
+.progressDiv .progress {
+ border-radius: var(--radius-lg) !important;
+ background: var(--highlight-color);
+ line-height: 3rem;
+ height: 48px;
+}
+
+.gallery-item {
+ box-shadow: none !important;
+}
+
+.performance {
+ color: #888;
+}
+
+.image-buttons {
+ justify-content: center;
+ gap: 0 !important;
+}
+
+.image-buttons>button {
+ max-width: 160px;
+}
+
+.tooltip {
+ background: var(--primary-300);
+ color: black;
+ border: none;
+ border-radius: var(--radius-lg);
+}
+
+#system_row>button,
+#settings_row>button,
+#config_row>button {
+ max-width: 10em;
+}
+
+/* Gradio Elements Overrides */
+div.gradio-container {
+ overflow-x: hidden;
+}
+
+#img2img_label_copy_to_img2img {
+ font-weight: normal;
+}
+
+#txt2img_styles,
+#img2img_styles,
+#control_styles {
+ padding: 0;
+ margin-top: 2px;
+}
+
+#txt2img_styles_refresh,
+#img2img_styles_refresh,
+#control_styles_refresh {
+ padding: 0;
+ margin-top: 1em;
+}
+
+#img2img_settings {
+ min-width: calc(2 * var(--left-column));
+ max-width: calc(2 * var(--left-column));
+ background-color: var(--neutral-950);
+ padding-top: 16px;
+}
+
+#interrogate,
+#deepbooru {
+ margin: 0 0px 10px 0px;
+ max-width: 80px;
+ max-height: 80px;
+ font-weight: normal;
+ font-size: 0.95em;
+}
+
+#quicksettings .gr-button-tool {
+ font-size: 1.6rem;
+ box-shadow: none;
+ margin-left: -20px;
+ margin-top: -2px;
+ height: 2.4em;
+}
+
+#save-animation {
+ border-radius: var(--radius-sm) !important;
+ margin-bottom: 16px;
+ background-color: var(--neutral-950);
+}
+
+#script_list {
+ padding: 4px;
+ margin-top: 16px;
+ margin-bottom: 8px;
+}
+
+#settings>div.flex-wrap {
+ width: 15em;
+}
+
+#txt2img_cfg_scale {
+ min-width: 200px;
+}
+
+#txt2img_checkboxes,
+#img2img_checkboxes,
+#control_checkboxes {
+ background-color: transparent;
+ margin-bottom: 0.2em;
+}
+
+#extras_upscale {
+ margin-top: 10px;
+}
+
+#txt2img_progress_row>div {
+ min-width: var(--left-column);
+ max-width: var(--left-column);
+}
+
+#txt2img_settings {
+ min-width: var(--left-column);
+ max-width: var(--left-column);
+ background-color: var(--neutral-950);
+}
+
+#pnginfo_html2_info {
+ margin-top: -18px;
+ background-color: var(--input-background-fill);
+ padding: var(--input-padding);
+}
+
+#txt2img_styles_row,
+#img2img_styles_row,
+#control_styles_row {
+ margin-top: -6px;
+}
+
+.block>span {
+ margin-bottom: 0 !important;
+ margin-top: var(--spacing-lg);
+}
+
+/* Extra Networks Container */
+#extra_networks_root {
+ z-index: 100;
+ background: var(--background-color);
+ box-shadow: var(--shadow-md);
+ border-radius: var(--radius-lg);
+ transform: translateX(100%);
+ animation: slideIn 0.5s forwards;
+ overflow: hidden;
+ /* Prevents overflow of content */
+}
+
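+/* Slide the extra networks panel in from the right edge */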
+@keyframes slideIn {
+ to {
+ transform: translateX(0);
+ }
+}
+
+/* Extra Networks Styles */
+.extra-networks {
+ border-left: 2px solid var(--highlight-color) !important;
+ padding-left: 4px;
+}
+
+.extra-networks .tab-nav>button:hover {
+ background: var(--highlight-color);
+}
+
+/* Network tab search and description important fix, don't remove */
+#txt2img_description,
+#txt2img_extra_search,
+#img2img_description,
+#img2img_extra_search,
+#control_description,
+#control_extra_search {
+ margin-top: 50px;
+}
+
+.extra-networks .buttons>button:hover {
+ background: var(--highlight-color);
+}
+
+/* Network Cards Container */
+.extra-network-cards {
+ display: flex;
+ flex-wrap: wrap;
+ overflow-y: auto;
+ overflow-x: hidden;
+ align-content: flex-start;
+ padding-top: 20px;
+ justify-content: center;
+ width: 100%;
+ /* Ensures it takes full width */
+}
+
+/* Individual Card Styles */
+.extra-network-cards .card {
+ height: fit-content;
+ margin: 0 0 0.5em 0.5em;
+ position: relative;
+ scroll-snap-align: start;
+ scroll-margin-top: 0;
+ background: var(--neutral-800);
+ /* Background for cards */
+ border-radius: var(--radius-md);
+ box-shadow: var(--shadow-md);
+ transition: var(--transition);
+}
+
+/* Overlay Styles */
+.extra-network-cards .card .overlay {
+ z-index: 10;
+ width: 100%;
+ background: none;
+ border-radius: var(--radius-md);
+}
+
+/* Overlay Name Styles */
+.extra-network-cards .card .overlay .name {
+ font-size: var(--text-lg);
+ font-weight: bold;
+ text-shadow: 1px 1px black;
+ color: white;
+ overflow-wrap: anywhere;
+ position: absolute;
+ bottom: 0;
+ padding: 0.2em;
+ z-index: 10;
+}
+
+/* Preview Styles */
+.extra-network-cards .card .preview {
+ box-shadow: var(--button-shadow);
+ min-height: 30px;
+ border-radius: var(--radius-md);
+ z-index: 9999;
+}
+
+/* Hover Effects */
+.extra-network-cards .card:hover {
+ transform: scale(1.3);
+ z-index: 9999; /* Use a high value to ensure it appears on top */
+ transition: transform 0.3s ease, z-index 0s; /* Smooth transition */
+}
+
+.extra-network-cards .card:hover .overlay {
+ z-index: 10000; /* Ensure overlay is also on top */
+}
+
+.extra-network-cards .card:hover .preview {
+ box-shadow: none;
+ filter: grayscale(0%);
+}
+
+/* Tags Styles */
+.extra-network-cards .card .overlay .tags {
+ display: none;
+ overflow-wrap: anywhere;
+ position: absolute;
+ top: 100%;
+ z-index: 20;
+ background: var(--body-background-fill);
+ overflow-x: hidden;
+ overflow-y: auto;
+ max-height: 333px;
+}
+
+/* Individual Tag Styles */
+.extra-network-cards .card .overlay .tag {
+ padding: 2px;
+ margin: 2px;
+ background: rgba(70, 70, 70, 0.60);
+ font-size: var(--text-md);
+ cursor: pointer;
+ display: inline-block;
+}
+
+/* Actions Styles */
+.extra-network-cards .card .actions>span {
+ padding: 4px;
+ font-size: 34px !important;
+}
+
+.extra-network-cards .card .actions {
+ background: none;
+}
+
+.extra-network-cards .card .actions .details {
+ bottom: 50px;
+ background-color: var(--neutral-800);
+}
+
+.extra-network-cards .card .actions>span:hover {
+ color: var(--highlight-color);
+}
+
+/* Version Styles */
+.extra-network-cards .card .version {
+ position: absolute;
+ top: 0;
+ left: 0;
+ padding: 2px;
+ font-weight: bolder;
+ text-shadow: 1px 1px black;
+ text-transform: uppercase;
+ background: gray;
+ opacity: 75%;
+ margin: 4px;
+ line-height: 0.9rem;
+}
+
+/* Hover Actions */
+.extra-network-cards .card:hover .actions {
+ display: block;
+}
+
+.extra-network-cards .card:hover .overlay .tags {
+ display: block;
+}
+
+/* No Preview Card Styles */
+.extra-network-cards .card:has(>img[src*="card-no-preview.png"])::before {
+ content: '';
+ position: absolute;
+ width: 100%;
+ height: 100%;
+ mix-blend-mode: multiply;
+ background-color: var(--data-color);
+}
+
+/* Card List Styles */
+.extra-network-cards .card-list {
+ display: flex;
+ margin: 0.3em;
+ padding: 0.3em;
+ background: var(--input-background-fill);
+ cursor: pointer;
+ border-radius: var(--button-large-radius);
+}
+
+.extra-network-cards .card-list .tag {
+ color: var(--primary-500);
+ margin-left: 0.8em;
+}
+
+/* Correction color picker styling */
+#txt2img_hdr_color_picker label input {
+ width: 100%;
+ height: 100%;
+}
+
+/* loader */
+.splash {
+ position: fixed;
+ top: 0;
+ left: 0;
+ width: 100vw;
+ height: 100vh;
+ z-index: 1000;
+ display: flex;
+ flex-direction: column;
+ align-items: center;
+ justify-content: center;
+ background-color: rgba(0, 0, 0, 0.8);
+}
+
+.motd {
+ margin-top: 1em;
+ color: var(--body-text-color-subdued);
+ font-family: monospace;
+ font-variant: all-petite-caps;
+ font-size: 1.2em;
+}
+
+.splash-img {
+ margin: 0;
+ width: 512px;
+ height: 512px;
+ background-repeat: no-repeat;
+ animation: color 8s infinite alternate, move 3s infinite alternate;
+}
+
+.loading {
+ color: white;
+ position: absolute;
+ top: 85%;
+ font-size: 1.5em;
+}
+
+.loader {
+ width: 100px;
+ height: 100px;
+ border: var(--spacing-md) solid transparent;
+ border-radius: 50%;
+ border-top: var(--spacing-md) solid var(--primary-600);
+ animation: spin 2s linear infinite, pulse 1.5s ease-in-out infinite;
+ position: relative;
+}
+
+.loader::before,
+.loader::after {
+ content: "";
+ position: absolute;
+ top: 6px;
+ bottom: 6px;
+ left: 6px;
+ right: 6px;
+ border-radius: 50%;
+ border: var(--spacing-md) solid transparent;
+}
+
+.loader::before {
+ border-top-color: var(--primary-900);
+ animation: spin 3s linear infinite;
+}
+
+.loader::after {
+ border-top-color: var(--primary-300);
+ animation: spin 1.5s linear infinite;
+}
+
+@keyframes move {
+ 0% {
+ transform: translateY(0);
+ }
+ 50% {
+ transform: translateY(-10px);
+ }
+ 100% {
+ transform: translateY(0);
+ }
+}
+
+@keyframes spin {
+ from {
+ transform: rotate(0deg);
+ }
+ to {
+ transform: rotate(360deg);
+ }
+}
+
+@keyframes pulse {
+ 0%, 100% {
+ transform: scale(1);
+ }
+ 50% {
+ transform: scale(1.1);
+ }
+}
+
+@keyframes color {
+ 0% {
+ filter: hue-rotate(0deg);
+ }
+ 100% {
+ filter: hue-rotate(360deg);
+ }
+}
+
+/* Token counters styling */
+#txt2img_token_counter, #txt2img_negative_token_counter {
+ display: flex;
+ flex-direction: row;
+ padding-top: 1px;
+ opacity: 0.6;
+ z-index: 99;
+}
+
+#txt2img_prompt_container {
+ margin: 5px;
+ padding: 0px;
+}
+
+#txt2img_prompt label, #txt2img_neg_prompt label {
+ margin: 0px;
+}
+
+/* Based on Gradio Built-in Dark Theme */
+:root,
+.light,
+.dark {
+ --body-background-fill: var(--background-color);
+ --color-accent-soft: var(--neutral-700);
+ --background-fill-secondary: none;
+ --border-color-accent: var(--background-color);
+ --border-color-primary: var(--background-color);
+ --link-text-color-active: var(--primary-500);
+ --link-text-color: var(--secondary-500);
+ --link-text-color-hover: var(--secondary-400);
+ --link-text-color-visited: var(--secondary-600);
+ --shadow-spread: 1px;
+ --block-background-fill: none;
+ --block-border-color: var(--border-color-primary);
+ --block_border_width: none;
+ --block-info-text-color: var(--body-text-color-subdued);
+ --block-label-background-fill: var(--background-fill-secondary);
+ --block-label-border-color: var(--border-color-primary);
+ --block_label_border_width: none;
+ --block-label-text-color: var(--neutral-200);
+ --block-shadow: none;
+ --block-title-background-fill: none;
+ --block-title-border-color: none;
+ --block-title-border-width: 0px;
+ --block-title-padding: 0;
+ --block-title-radius: none;
+ --block-title-text-size: var(--text-md);
+ --block-title-text-weight: 400;
+ --container-radius: var(--radius-lg);
+ --form-gap-width: 1px;
+ --layout-gap: var(--spacing-xxl);
+ --panel-border-width: 0;
+ --section-header-text-size: var(--text-md);
+ --section-header-text-weight: 400;
+ --checkbox-border-radius: var(--radius-sm);
+ --checkbox-label-gap: 2px;
+ --checkbox-label-padding: var(--spacing-md);
+ --checkbox-label-shadow: var(--shadow-drop);
+ --checkbox-label-text-size: var(--text-md);
+ --checkbox-label-text-weight: 400;
+ --checkbox-check: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e");
+ --radio-circle: url("data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='8' cy='8' r='3'/%3e%3c/svg%3e");
+ --checkbox-shadow: var(--input-shadow);
+ --error-border-width: 1px;
+ --input-border-width: 0;
+ --input-radius: var(--radius-lg);
+ --input-text-size: var(--text-md);
+ --input-text-weight: 400;
+ --prose-text-size: var(--text-md);
+ --prose-text-weight: 400;
+ --prose-header-text-weight: 400;
+ --slider-color: var(--neutral-900);
+ --table-radius: var(--radius-lg);
+ --button-large-padding: 2px 6px;
+ --button-large-radius: var(--radius-lg);
+ --button-large-text-size: var(--text-lg);
+ --button-large-text-weight: 400;
+ --button-shadow: none;
+ --button-shadow-active: none;
+ --button-shadow-hover: none;
+ --button-small-padding: var(--spacing-sm) calc(2 * var(--spacing-sm));
+ --button-small-radius: var(--radius-lg);
+ --button-small-text-size: var(--text-md);
+ --button-small-text-weight: 400;
+ --button-transition: none;
+ --size-9: 64px;
+ --size-14: 64px;
+}
\ No newline at end of file
diff --git a/javascript/black-teal.css b/javascript/black-teal.css
index c6f266c54..2ebf32e96 100644
--- a/javascript/black-teal.css
+++ b/javascript/black-teal.css
@@ -134,7 +134,7 @@ svg.feather.feather-image, .feather .feather-image { display: none }
.gallery-item { box-shadow: none !important; }
.performance { color: #888; }
.extra-networks { border-left: 2px solid var(--highlight-color) !important; padding-left: 4px; }
-.image-buttons { gap: 10px !important; justify-content: center; }
+.image-buttons { justify-content: center; gap: 0 !important; }
.image-buttons > button { max-width: 160px; }
.tooltip { background: var(--primary-300); color: black; border: none; border-radius: var(--radius-lg) }
#system_row > button, #settings_row > button, #config_row > button { max-width: 10em; }
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index 77fe125f3..1d1bcfb24 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -3,19 +3,6 @@ let sortVal = -1;
// helpers
-const requestGet = (url, data, handler) => {
- const xhr = new XMLHttpRequest();
- const args = Object.keys(data).map((k) => `${encodeURIComponent(k)}=${encodeURIComponent(data[k])}`).join('&');
- xhr.open('GET', `${url}?${args}`, true);
- xhr.onreadystatechange = () => {
- if (xhr.readyState === 4) {
- if (xhr.status === 200) handler(JSON.parse(xhr.responseText));
- else console.error(`Request: url=${url} status=${xhr.status} err`);
- }
- };
- xhr.send(JSON.stringify(data));
-};
-
const getENActiveTab = () => {
let tabName = '';
if (gradioApp().getElementById('tab_txt2img').style.display === 'block') tabName = 'txt2img';
@@ -98,7 +85,7 @@ function readCardTags(el, tags) {
}
function readCardDescription(page, item) {
- requestGet('/sd_extra_networks/description', { page, item }, (data) => {
+ xhrGet('/sd_extra_networks/description', { page, item }, (data) => {
const tabname = getENActiveTab();
const description = gradioApp().querySelector(`#${tabname}_description > label > textarea`);
description.value = data?.description?.trim() || '';
@@ -447,6 +434,22 @@ function setupExtraNetworksForTab(tabname) {
};
}
+ // auto-resize networks sidebar
+ const resizeObserver = new ResizeObserver((entries) => {
+ for (const entry of entries) {
+ for (const el of Array.from(gradioApp().getElementById(`${tabname}_extra_tabs`).querySelectorAll('.extra-networks-page'))) {
+ const h = Math.trunc(entry.contentRect.height);
+ if (h <= 0) return;
+ if (window.opts.extra_networks_card_cover === 'sidebar' && window.opts.theme_type === 'Standard') el.style.height = `max(55vh, ${h - 90}px)`;
+ // log(`${tabname} height: ${entry.target.id}=${h} ${el.id}=${el.clientHeight}`);
+ }
+ }
+ });
+ const settingsEl = gradioApp().getElementById(`${tabname}_settings`);
+ const interfaceEl = gradioApp().getElementById(`${tabname}_interface`);
+ if (settingsEl) resizeObserver.observe(settingsEl);
+ if (interfaceEl) resizeObserver.observe(interfaceEl);
+
// en style
if (!en) return;
let lastView;
diff --git a/javascript/gallery.js b/javascript/gallery.js
index 1f3afd148..05e594e4c 100644
--- a/javascript/gallery.js
+++ b/javascript/gallery.js
@@ -94,14 +94,14 @@ async function delayFetchThumb(fn) {
outstanding++;
const res = await fetch(`/sdapi/v1/browser/thumb?file=${encodeURI(fn)}`, { priority: 'low' });
if (!res.ok) {
- console.error(res.statusText);
+ error(`fetchThumb: ${res.statusText}`);
outstanding--;
return undefined;
}
const json = await res.json();
outstanding--;
if (!res || !json || json.error || Object.keys(json).length === 0) {
- if (json.error) console.error(json.error);
+ if (json.error) error(`fetchThumb: ${json.error}`);
return undefined;
}
return json;
diff --git a/javascript/imageMaskFix.js b/javascript/imageMaskFix.js
deleted file mode 100644
index fd37caf90..000000000
--- a/javascript/imageMaskFix.js
+++ /dev/null
@@ -1,38 +0,0 @@
-/**
- * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
- * @see https://github.com/gradio-app/gradio/issues/1721
- */
-function imageMaskResize() {
- const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
- if (!canvases.length) {
- window.removeEventListener('resize', imageMaskResize);
- return;
- }
- const wrapper = canvases[0].closest('.touch-none');
- const previewImage = wrapper.previousElementSibling;
- if (!previewImage.complete) {
- previewImage.addEventListener('load', imageMaskResize);
- return;
- }
- const w = previewImage.width;
- const h = previewImage.height;
- const nw = previewImage.naturalWidth;
- const nh = previewImage.naturalHeight;
- const portrait = nh > nw;
- const wW = Math.min(w, portrait ? h / nh * nw : w / nw * nw);
- const wH = Math.min(h, portrait ? h / nh * nh : w / nw * nh);
- wrapper.style.width = `${wW}px`;
- wrapper.style.height = `${wH}px`;
- wrapper.style.left = '0px';
- wrapper.style.top = '0px';
- canvases.forEach((c) => {
- c.style.width = '';
- c.style.height = '';
- c.style.maxWidth = '100%';
- c.style.maxHeight = '100%';
- c.style.objectFit = 'contain';
- });
-}
-
-onAfterUiUpdate(imageMaskResize);
-window.addEventListener('resize', imageMaskResize);
diff --git a/javascript/light-teal.css b/javascript/light-teal.css
index 28bf03e6f..174622e52 100644
--- a/javascript/light-teal.css
+++ b/javascript/light-teal.css
@@ -20,9 +20,9 @@
--body-text-color: var(--neutral-800);
--body-text-color-subdued: var(--neutral-600);
--background-color: #FFFFFF;
- --background-fill-primary: var(--neutral-400);
+ --background-fill-primary: var(--neutral-300);
--input-padding: 4px;
- --input-background-fill: var(--neutral-300);
+ --input-background-fill: var(--neutral-200);
--input-shadow: 2px 2px 2px 2px var(--neutral-500);
--button-secondary-text-color: black;
--button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-200), var(--neutral-500));
@@ -291,8 +291,8 @@ svg.feather.feather-image, .feather .feather-image { display: none }
--slider-color: ;
--stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600));
--table-border-color: var(--neutral-700);
- --table-even-background-fill: #222222;
- --table-odd-background-fill: #333333;
+ --table-even-background-fill: #FFFFFF;
+ --table-odd-background-fill: #CCCCCC;
--table-radius: var(--radius-lg);
--table-row-focus: var(--color-accent-soft);
}
diff --git a/javascript/loader.js b/javascript/loader.js
index f3c7fe60f..8cd4811bf 100644
--- a/javascript/loader.js
+++ b/javascript/loader.js
@@ -20,7 +20,7 @@ async function preloadImages() {
try {
await Promise.all(imagePromises);
} catch (error) {
- console.error('Error preloading images:', error);
+ error(`preloadImages: ${error}`);
}
}
@@ -43,14 +43,16 @@ async function createSplash() {
const motdEl = document.getElementById('motd');
if (motdEl) motdEl.innerHTML = text.replace(/["]+/g, '');
})
- .catch((err) => console.error('getMOTD:', err));
+ .catch((err) => error(`getMOTD: ${err}`));
}
async function removeSplash() {
const splash = document.getElementById('splash');
if (splash) splash.remove();
log('removeSplash');
- log('startupTime', Math.round(performance.now() - appStartTime) / 1000);
+ const t = Math.round(performance.now() - appStartTime) / 1000;
+ log('startupTime', t);
+ xhrPost('/sdapi/v1/log', { message: `ready time=${t}` });
}
window.onload = createSplash;
diff --git a/javascript/logMonitor.js b/javascript/logMonitor.js
index e4fe99a7f..9b915e6da 100644
--- a/javascript/logMonitor.js
+++ b/javascript/logMonitor.js
@@ -2,6 +2,7 @@ let logMonitorEl = null;
let logMonitorStatus = true;
let logWarnings = 0;
let logErrors = 0;
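+// true once the log endpoint has responded; reset on fetch failure so a reconnect can be reported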
+let logConnected = false;
function dateToStr(ts) {
const dt = new Date(1000 * ts);
@@ -29,8 +30,7 @@ async function logMonitor() {
row.innerHTML = `${dateToStr(l.created)} | ${level}${l.facility} | ${module}${l.msg} | `;
logMonitorEl.appendChild(row);
} catch (e) {
- // console.log('logMonitor', e);
- console.error('logMonitor line', line);
+ error(`logMonitor: ${line}`);
}
};
@@ -46,6 +46,7 @@ async function logMonitor() {
if (logMonitorStatus) setTimeout(logMonitor, opts.logmonitor_refresh_period);
else setTimeout(logMonitor, 10 * 1000); // on failure try to reconnect every 10sec
+
if (!opts.logmonitor_show) return;
logMonitorStatus = false;
if (!logMonitorEl) {
@@ -64,14 +65,20 @@ async function logMonitor() {
const lines = await res.json();
if (logMonitorEl && lines?.length > 0) logMonitorEl.parentElement.parentElement.style.display = opts.logmonitor_show ? 'block' : 'none';
for (const line of lines) addLogLine(line);
+ if (!logConnected) {
+ logConnected = true;
+ xhrPost('/sdapi/v1/log', { debug: 'connected' });
+ }
} else {
- addLogLine(`{ "created": ${Date.now()}, "level":"ERROR", "module":"logMonitor", "facility":"ui", "msg":"Failed to fetch log: ${res?.status} ${res?.statusText}" }`);
+ logConnected = false;
logErrors++;
+ addLogLine(`{ "created": ${Date.now()}, "level":"ERROR", "module":"logMonitor", "facility":"ui", "msg":"Failed to fetch log: ${res?.status} ${res?.statusText}" }`);
}
cleanupLog(atBottom);
} catch (err) {
- addLogLine(`{ "created": ${Date.now()}, "level":"ERROR", "module":"logMonitor", "facility":"ui", "msg":"Failed to fetch log: server unreachable" }`);
+ logConnected = false;
logErrors++;
+ addLogLine(`{ "created": ${Date.now()}, "level":"ERROR", "module":"logMonitor", "facility":"ui", "msg":"Failed to fetch log: server unreachable" }`);
cleanupLog(atBottom);
}
}
diff --git a/javascript/logger.js b/javascript/logger.js
new file mode 100644
index 000000000..8fa812b86
--- /dev/null
+++ b/javascript/logger.js
@@ -0,0 +1,68 @@
+const timeout = 10000;
+
+const log = async (...msg) => {
+ const dt = new Date();
+ const ts = `${dt.getHours().toString().padStart(2, '0')}:${dt.getMinutes().toString().padStart(2, '0')}:${dt.getSeconds().toString().padStart(2, '0')}.${dt.getMilliseconds().toString().padStart(3, '0')}`;
+ if (window.logger) window.logger.innerHTML += window.logPrettyPrint(...msg);
+ console.log(ts, ...msg); // eslint-disable-line no-console
+};
+
+const debug = async (...msg) => {
+ const dt = new Date();
+ const ts = `${dt.getHours().toString().padStart(2, '0')}:${dt.getMinutes().toString().padStart(2, '0')}:${dt.getSeconds().toString().padStart(2, '0')}.${dt.getMilliseconds().toString().padStart(3, '0')}`;
+ if (window.logger) window.logger.innerHTML += window.logPrettyPrint(...msg);
+ console.debug(ts, ...msg); // eslint-disable-line no-console
+};
+
+const error = async (...msg) => {
+ const dt = new Date();
+ const ts = `${dt.getHours().toString().padStart(2, '0')}:${dt.getMinutes().toString().padStart(2, '0')}:${dt.getSeconds().toString().padStart(2, '0')}.${dt.getMilliseconds().toString().padStart(3, '0')}`;
+ if (window.logger) window.logger.innerHTML += window.logPrettyPrint(...msg);
+ console.error(ts, ...msg); // eslint-disable-line no-console
+ // const txt = msg.join(' ');
+ // if (!txt.includes('asctime') && !txt.includes('xhr.')) xhrPost('/sdapi/v1/log', { error: txt }); // eslint-disable-line no-use-before-define
+};
+
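+// shared XHR wrapper: sets JSON headers and timeout, parses the response as JSON on HTTP 200, and reports failures via error() unless ignore is set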
+const xhrInternal = (xhrObj, data, handler = undefined, errorHandler = undefined, ignore = false, serverTimeout = timeout) => {
+ const err = (msg) => {
+ if (!ignore) {
+ error(`${msg}: state=${xhrObj.readyState} status=${xhrObj.status} response=${xhrObj.responseText}`);
+ if (errorHandler) errorHandler(xhrObj);
+ }
+ };
+
+ xhrObj.setRequestHeader('Content-Type', 'application/json');
+ xhrObj.timeout = serverTimeout;
+ xhrObj.ontimeout = () => err('xhr.ontimeout');
+ xhrObj.onerror = () => err('xhr.onerror');
+ xhrObj.onabort = () => err('xhr.onabort');
+ xhrObj.onreadystatechange = () => {
+ if (xhrObj.readyState === 4) {
+ if (xhrObj.status === 200) {
+ try {
+ const json = JSON.parse(xhrObj.responseText);
+ if (handler) handler(json);
+ } catch (e) {
+ error(`xhr.onreadystatechange: ${e}`);
+ }
+ } else {
+ err(`xhr.onreadystatechange: state=${xhrObj.readyState} status=${xhrObj.status} response=${xhrObj.responseText}`);
+ }
+ }
+ };
+ const req = JSON.stringify(data);
+ xhrObj.send(req);
+};
+
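+// GET helper: encodes data as query-string parameters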
+const xhrGet = (url, data, handler = undefined, errorHandler = undefined, ignore = false, serverTimeout = timeout) => {
+ const xhr = new XMLHttpRequest();
+ const args = Object.keys(data).map((k) => `${encodeURIComponent(k)}=${encodeURIComponent(data[k])}`).join('&');
+ xhr.open('GET', `${url}?${args}`, true);
+ xhrInternal(xhr, data, handler, errorHandler, ignore, serverTimeout);
+};
+
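+// POST helper: sends data as a JSON body
+// example usage (as in loader.js): xhrPost('/sdapi/v1/log', { message: 'ready' });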
+function xhrPost(url, data, handler = undefined, errorHandler = undefined, ignore = false, serverTimeout = timeout) {
+ const xhr = new XMLHttpRequest();
+ xhr.open('POST', url, true);
+ xhrInternal(xhr, data, handler, errorHandler, ignore, serverTimeout);
+}
diff --git a/javascript/notification.js b/javascript/notification.js
index 33e8d1c55..c702c90e7 100644
--- a/javascript/notification.js
+++ b/javascript/notification.js
@@ -4,28 +4,32 @@ let lastHeadImg = null;
let notificationButton = null;
async function sendNotification() {
- if (!notificationButton) {
- notificationButton = gradioApp().getElementById('request_notifications');
- if (notificationButton) notificationButton.addEventListener('click', (evt) => Notification.requestPermission(), true);
+ try {
+ if (!notificationButton) {
+ notificationButton = gradioApp().getElementById('request_notifications');
+ if (notificationButton) notificationButton.addEventListener('click', (evt) => Notification.requestPermission(), true);
+ }
+ if (document.hasFocus()) return; // window is in focus so don't send notifications
+ let galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] .thumbnail-item > img');
+ if (!galleryPreviews || galleryPreviews.length === 0) galleryPreviews = gradioApp().querySelectorAll('.thumbnail-item > img');
+ if (!galleryPreviews || galleryPreviews.length === 0) return;
+ const headImg = galleryPreviews[0]?.src;
+ if (!headImg || headImg === lastHeadImg || headImg.includes('logo-bg-')) return;
+ const audioNotification = gradioApp().querySelector('#audio_notification audio');
+ if (audioNotification) audioNotification.play();
+ lastHeadImg = headImg;
+ const imgs = new Set(Array.from(galleryPreviews).map((img) => img.src)); // Multiple copies of the images are in the DOM when one is selected
+ const notification = new Notification('SD.Next', {
+ body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 's' : ''}`,
+ icon: headImg,
+ image: headImg,
+ });
+ notification.onclick = () => {
+ parent.focus();
+ this.close();
+ };
+ log('sendNotifications');
+ } catch (e) {
+ error(`sendNotification: ${e}`);
}
- if (document.hasFocus()) return; // window is in focus so don't send notifications
- let galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] .thumbnail-item > img');
- if (!galleryPreviews || galleryPreviews.length === 0) galleryPreviews = gradioApp().querySelectorAll('.thumbnail-item > img');
- if (!galleryPreviews || galleryPreviews.length === 0) return;
- const headImg = galleryPreviews[0]?.src;
- if (!headImg || headImg === lastHeadImg || headImg.includes('logo-bg-')) return;
- const audioNotification = gradioApp().querySelector('#audio_notification audio');
- if (audioNotification) audioNotification.play();
- lastHeadImg = headImg;
- const imgs = new Set(Array.from(galleryPreviews).map((img) => img.src)); // Multiple copies of the images are in the DOM when one is selected
- const notification = new Notification('SD.Next', {
- body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 's' : ''}`,
- icon: headImg,
- image: headImg,
- });
- notification.onclick = () => {
- parent.focus();
- this.close();
- };
- log('sendNotifications');
}
diff --git a/javascript/nvml.js b/javascript/nvml.js
index cf0187367..0a82cba1b 100644
--- a/javascript/nvml.js
+++ b/javascript/nvml.js
@@ -1,6 +1,32 @@
let nvmlInterval = null; // eslint-disable-line prefer-const
let nvmlEl = null;
let nvmlTable = null;
+const chartData = { mem: [], load: [] };
+
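+// keep a rolling window of the last 120 GPU load/memory samples and draw them as composite jQuery sparkline bar charts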
+async function updateNVMLChart(mem, load) {
+ const maxLen = 120;
+ const colorRangeMap = $.range_map({ // eslint-disable-line no-undef
+ '0:5': '#fffafa',
+ '6:10': '#fff7ed',
+ '11:20': '#fed7aa',
+ '21:30': '#fdba74',
+ '31:40': '#fb923c',
+ '41:50': '#f97316',
+ '51:60': '#ea580c',
+ '61:70': '#c2410c',
+ '71:80': '#9a3412',
+ '81:90': '#7c2d12',
+ '91:100': '#6c2e12',
+ });
+ const sparklineConfigLOAD = { type: 'bar', height: '100px', barWidth: '2px', barSpacing: '1px', chartRangeMin: 0, chartRangeMax: 100, barColor: '#89007D' };
+ const sparklineConfigMEM = { type: 'bar', height: '100px', barWidth: '2px', barSpacing: '1px', chartRangeMin: 0, chartRangeMax: 100, colorMap: colorRangeMap, composite: true };
+ if (chartData.load.length > maxLen) chartData.load.shift();
+ chartData.load.push(load);
+ if (chartData.mem.length > maxLen) chartData.mem.shift();
+ chartData.mem.push(mem);
+ $('#nvmlChart').sparkline(chartData.load, sparklineConfigLOAD); // eslint-disable-line no-undef
+ $('#nvmlChart').sparkline(chartData.mem, sparklineConfigMEM); // eslint-disable-line no-undef
+}
async function updateNVML() {
try {
@@ -35,6 +61,7 @@ async function updateNVML() {
State | ${gpu.state} |
`;
nvmlTbody.innerHTML = rows;
+ updateNVMLChart(gpu.load.memory, gpu.load.gpu);
}
nvmlEl.style.display = 'block';
} catch (e) {
@@ -56,7 +83,10 @@ async function initNVML() {
| |
`;
+ const nvmlChart = document.createElement('div');
+ nvmlChart.id = 'nvmlChart';
nvmlEl.appendChild(nvmlTable);
+ nvmlEl.appendChild(nvmlChart);
gradioApp().appendChild(nvmlEl);
log('initNVML');
}
diff --git a/javascript/progressBar.js b/javascript/progressBar.js
index a9ecb31e9..dfd895f8e 100644
--- a/javascript/progressBar.js
+++ b/javascript/progressBar.js
@@ -1,28 +1,5 @@
let lastState = {};
-function request(url, data, handler, errorHandler) {
- const xhr = new XMLHttpRequest();
- xhr.open('POST', url, true);
- xhr.setRequestHeader('Content-Type', 'application/json');
- xhr.onreadystatechange = () => {
- if (xhr.readyState === 4) {
- if (xhr.status === 200) {
- try {
- const js = JSON.parse(xhr.responseText);
- handler(js);
- } catch (error) {
- console.error(error);
- errorHandler();
- }
- } else {
- errorHandler();
- }
- }
- };
- const js = JSON.stringify(data);
- xhr.send(js);
-}
-
function pad2(x) {
return x < 10 ? `0${x}` : x;
}
@@ -35,8 +12,10 @@ function formatTime(secs) {
function checkPaused(state) {
lastState.paused = state ? !state : !lastState.paused;
- document.getElementById('txt2img_pause').innerText = lastState.paused ? 'Resume' : 'Pause';
- document.getElementById('img2img_pause').innerText = lastState.paused ? 'Resume' : 'Pause';
+ const t_el = document.getElementById('txt2img_pause');
+ const i_el = document.getElementById('img2img_pause');
+ if (t_el) t_el.innerText = lastState.paused ? 'Resume' : 'Pause';
+ if (i_el) i_el.innerText = lastState.paused ? 'Resume' : 'Pause';
}
function setProgress(res) {
@@ -89,28 +68,42 @@ function requestProgress(id_task, progressEl, galleryEl, atEnd = null, onProgres
let img;
const initLivePreview = () => {
- img = new Image();
- if (parentGallery) {
- livePreview = document.createElement('div');
- livePreview.className = 'livePreview';
- parentGallery.insertBefore(livePreview, galleryEl);
- const rect = galleryEl.getBoundingClientRect();
- if (rect.width) {
- livePreview.style.width = `${rect.width}px`;
- livePreview.style.height = `${rect.height}px`;
- }
- img.onload = () => {
- livePreview.appendChild(img);
- if (livePreview.childElementCount > 2) livePreview.removeChild(livePreview.firstElementChild);
- };
+ if (!parentGallery) return;
+ const footers = Array.from(gradioApp().querySelectorAll('.gallery_footer'));
+ for (const footer of footers) {
+ if (footer.id !== 'gallery_footer') footer.style.display = 'none'; // hide all other footers
+ }
+ const galleries = Array.from(gradioApp().querySelectorAll('.gallery_main'));
+ for (const gallery of galleries) {
+ if (gallery.id !== 'gallery_gallery') gallery.style.display = 'none'; // hide all other galleries
}
+
+ livePreview = document.createElement('div');
+ livePreview.className = 'livePreview';
+ parentGallery.insertBefore(livePreview, galleryEl);
+ img = new Image();
+ img.id = 'livePreviewImage';
+ livePreview.appendChild(img);
+ img.onload = () => {
+ img.style.width = `min(100%, max(${img.naturalWidth}px, 512px))`;
+ parentGallery.style.minHeight = `${img.height}px`;
+ };
};
const done = () => {
debug('taskEnd:', id_task);
localStorage.removeItem('task');
setProgress();
- if (parentGallery && livePreview) parentGallery.removeChild(livePreview);
+ const footers = Array.from(gradioApp().querySelectorAll('.gallery_footer'));
+ for (const footer of footers) footer.style.display = 'flex'; // restore all footers
+ const galleries = Array.from(gradioApp().querySelectorAll('.gallery_main'));
+ for (const gallery of galleries) gallery.style.display = 'flex'; // restore all galleries
+ try {
+ if (parentGallery && livePreview) {
+ parentGallery.removeChild(livePreview);
+ parentGallery.style.minHeight = 'unset';
+ }
+ } catch { /* ignore */ }
checkPaused(true);
sendNotification();
if (atEnd) atEnd();
@@ -118,20 +111,32 @@ function requestProgress(id_task, progressEl, galleryEl, atEnd = null, onProgres
const start = (id_task, id_live_preview) => { // eslint-disable-line no-shadow
if (!opts.live_previews_enable || opts.live_preview_refresh_period === 0 || opts.show_progress_every_n_steps === 0) return;
- request('./internal/progress', { id_task, id_live_preview }, (res) => {
+
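+ // handle each progress response: update the progress bar and live preview, detect completion, then re-poll after live_preview_refresh_period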
+ const onProgressHandler = (res) => {
+ // debug('onProgress', res);
lastState = res;
const elapsedFromStart = (new Date() - dateStart) / 1000;
hasStarted |= res.active;
if (res.completed || (!res.active && (hasStarted || once)) || (elapsedFromStart > 30 && !res.queued && res.progress === prevProgress)) {
+ debug('onProgressEnd', res);
done();
return;
}
setProgress(res);
if (res.live_preview && !livePreview) initLivePreview();
- if (res.live_preview && galleryEl) img.src = res.live_preview;
+ if (res.live_preview && galleryEl) {
+ if (img.src !== res.live_preview) img.src = res.live_preview;
+ }
if (onProgress) onProgress(res);
setTimeout(() => start(id_task, id_live_preview), opts.live_preview_refresh_period || 500);
- }, done);
+ };
+
+ const onProgressErrorHandler = (err) => {
+ error(`onProgressError: ${err}`);
+ done();
+ };
+
+ xhrPost('./internal/progress', { id_task, id_live_preview }, onProgressHandler, onProgressErrorHandler, false, 5000);
};
start(id_task, 0);
}
diff --git a/javascript/script.js b/javascript/script.js
index 104567dd7..f943f4626 100644
--- a/javascript/script.js
+++ b/javascript/script.js
@@ -1,17 +1,3 @@
-const log = (...msg) => {
- const dt = new Date();
- const ts = `${dt.getHours().toString().padStart(2, '0')}:${dt.getMinutes().toString().padStart(2, '0')}:${dt.getSeconds().toString().padStart(2, '0')}.${dt.getMilliseconds().toString().padStart(3, '0')}`;
- if (window.logger) window.logger.innerHTML += window.logPrettyPrint(...msg);
- console.log(ts, ...msg); // eslint-disable-line no-console
-};
-
-const debug = (...msg) => {
- const dt = new Date();
- const ts = `${dt.getHours().toString().padStart(2, '0')}:${dt.getMinutes().toString().padStart(2, '0')}:${dt.getSeconds().toString().padStart(2, '0')}.${dt.getMilliseconds().toString().padStart(3, '0')}`;
- if (window.logger) window.logger.innerHTML += window.logPrettyPrint(...msg);
- console.debug(ts, ...msg); // eslint-disable-line no-console
-};
-
async function sleep(ms) {
return new Promise((resolve) => setTimeout(resolve, ms)); // eslint-disable-line no-promise-executor-return
}
@@ -82,7 +68,7 @@ function executeCallbacks(queue, arg) {
try {
callback(arg);
} catch (e) {
- console.error('error running callback', callback, ':', e);
+ error(`executeCallbacks: ${callback} ${e}`);
}
}
}
@@ -139,11 +125,12 @@ document.addEventListener('keydown', (e) => {
let elem;
if (e.key === 'Escape') elem = getUICurrentTabContent().querySelector('button[id$=_interrupt]');
if (e.key === 'Enter' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id$=_generate]');
- if (e.key === 'Backspace' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id$=_reprocess]');
+ if (e.key === 'i' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id$=_reprocess]');
if (e.key === ' ' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id$=_extra_networks_btn]');
+ if (e.key === 'n' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id$=_extra_networks_btn]');
if (e.key === 's' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id^=save_]');
if (e.key === 'Insert' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id^=save_]');
- if (e.key === 'Delete' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id^=delete_]');
+ if (e.key === 'd' && e.ctrlKey) elem = getUICurrentTabContent().querySelector('button[id^=delete_]');
// if (e.key === 'm' && e.ctrlKey) elem = gradioApp().getElementById('setting_sd_model_checkpoint');
if (elem) {
e.preventDefault();
diff --git a/javascript/sdnext.css b/javascript/sdnext.css
index 08fae2eb8..6e33551ef 100644
--- a/javascript/sdnext.css
+++ b/javascript/sdnext.css
@@ -14,9 +14,9 @@ table { overflow-x: auto !important; overflow-y: auto !important; }
td { border-bottom: none !important; padding: 0 0.5em !important; }
tr { border-bottom: none !important; padding: 0 0.5em !important; }
td > div > span { overflow-y: auto; max-height: 3em; overflow-x: hidden; }
-textarea { overflow-y: auto !important; }
+textarea { overflow-y: auto !important; border-radius: 4px !important; }
span { font-size: var(--text-md) !important; }
-button { font-size: var(--text-lg) !important; }
+button { font-size: var(--text-lg) !important; min-width: unset !important; }
input[type='color'] { width: 64px; height: 32px; }
input::-webkit-outer-spin-button, input::-webkit-inner-spin-button { margin-left: 4px; }
@@ -30,6 +30,19 @@ input::-webkit-outer-spin-button, input::-webkit-inner-spin-button { margin-left
.hidden { display: none; }
.tabitem { padding: 0 !important; }
+/* gradio image/canvas elements */
+.image-container { overflow: auto; }
+/*
+.gradio-image { min-height: fit-content; }
+.gradio-image img { object-fit: contain; }
+*/
+/*
+.gradio-image { min-height: 200px !important; }
+.image-container { height: unset !important; }
+.control-image { height: unset !important; }
+#img2img_sketch, #img2maskimg, #inpaint_sketch { overflow: overlay !important; resize: auto; background: var(--panel-background-fill); z-index: 5; }
+*/
+
/* color elements */
.gradio-dropdown, .block.gradio-slider, .block.gradio-checkbox, .block.gradio-textbox, .block.gradio-radio, .block.gradio-checkboxgroup, .block.gradio-number, .block.gradio-colorpicker { border-width: 0 !important; box-shadow: none !important;}
.gradio-accordion { padding-top: var(--spacing-md) !important; padding-right: 0 !important; padding-bottom: 0 !important; color: var(--body-text-color); }
@@ -83,13 +96,12 @@ button.custom-button { border-radius: var(--button-large-radius); padding: var(-
.block.token-counter div{ display: inline; }
.block.token-counter span{ padding: 0.1em 0.75em; }
.performance { font-size: var(--text-xs); color: #444; }
-.performance p { display: inline-block; color: var(--body-text-color-subdued) !important }
+.performance p { display: inline-block; color: var(--primary-500) !important }
.performance .time { margin-right: 0; }
.thumbnails { background: var(--body-background-fill); }
-.control-image { height: calc(100vw/3) !important; }
.prompt textarea { resize: vertical; }
+.grid-wrap { overflow-y: auto !important; }
#control_results { margin: 0; padding: 0; }
-#control_gallery { height: calc(100vw/3 + 60px); }
#txt2img_gallery, #img2img_gallery { height: 50vh; }
#control-result { background: var(--button-secondary-background-fill); padding: 0.2em; }
#control-inputs { margin-top: 1em; }
@@ -105,7 +117,6 @@ button.custom-button { border-radius: var(--button-large-radius); padding: var(-
#txt2img_prompt, #txt2img_neg_prompt, #img2img_prompt, #img2img_neg_prompt, #control_prompt, #control_neg_prompt { display: contents; }
#txt2img_actions_column, #img2img_actions_column, #control_actions { flex-flow: wrap; justify-content: space-between; }
-
.interrogate-clip { position: absolute; right: 6em; top: 8px; max-width: fit-content; background: none !important; z-index: 50; }
.interrogate-blip { position: absolute; right: 4em; top: 8px; max-width: fit-content; background: none !important; z-index: 50; }
.interrogate-col { min-width: 0 !important; max-width: fit-content; margin-right: var(--spacing-xxl); }
@@ -118,11 +129,9 @@ div#extras_scale_to_tab div.form { flex-direction: row; }
#img2img_unused_scale_by_slider { visibility: hidden; width: 0.5em; max-width: 0.5em; min-width: 0.5em; }
.inactive{ opacity: 0.5; }
div#extras_scale_to_tab div.form { flex-direction: row; }
-#mode_img2img .gradio-image>div.fixed-height, #mode_img2img .gradio-image>div.fixed-height img{ height: 480px !important; max-height: 480px !important; min-height: 480px !important; }
-#img2img_sketch, #img2maskimg, #inpaint_sketch { overflow: overlay !important; resize: auto; background: var(--panel-background-fill); z-index: 5; }
.image-buttons button { min-width: auto; }
.infotext { overflow-wrap: break-word; line-height: 1.5em; font-size: 0.95em !important; }
-.infotext > p { padding-left: 1em; text-indent: -1em; white-space: pre-wrap; color: var(--block-info-text-color) !important; }
+.infotext > p { white-space: pre-wrap; color: var(--block-info-text-color) !important; }
.tooltip { display: block; position: fixed; top: 1em; right: 1em; padding: 0.5em; background: var(--input-background-fill); color: var(--body-text-color); border: 1pt solid var(--button-primary-border-color);
width: 22em; min-height: 1.3em; font-size: var(--text-xs); transition: opacity 0.2s ease-in; pointer-events: none; opacity: 0; z-index: 999; }
.tooltip-show { opacity: 0.9; }
@@ -140,29 +149,30 @@ div#extras_scale_to_tab div.form { flex-direction: row; }
#settings>div.tab-content { flex: 10 0 75%; display: grid; }
#settings>div.tab-content>div { border: none; padding: 0; }
#settings>div.tab-content>div>div>div>div>div { flex-direction: unset; }
-#settings>div.tab-nav { display: grid; grid-template-columns: repeat(auto-fill, .5em minmax(10em, 1fr)); flex: 1 0 auto; width: 12em; align-self: flex-start; gap: var(--spacing-xxl); }
+#settings>div.tab-nav { display: grid; grid-template-columns: repeat(auto-fill, .5em minmax(10em, 1fr)); flex: 1 0 auto; width: 12em; align-self: flex-start; gap: 8px; }
#settings>div.tab-nav button { display: block; border: none; text-align: left; white-space: initial; padding: 0; }
#settings>div.tab-nav>#settings_show_all_pages { padding: var(--size-2) var(--size-4); }
#settings .block.gradio-checkbox { margin: 0; width: auto; }
#settings .dirtyable { gap: .5em; }
#settings .dirtyable.hidden { display: none; }
-#settings .modification-indicator { height: 1.2em; border-radius: 1em !important; padding: 0; width: 0; margin-right: 0.5em; }
+#settings .modification-indicator { height: 1.2em; border-radius: 1em !important; padding: 0; width: 0; margin-right: 0.5em; border-left: inset; }
#settings .modification-indicator:disabled { visibility: hidden; }
#settings .modification-indicator.saved { background: var(--color-accent-soft); width: var(--spacing-sm); }
#settings .modification-indicator.changed { background: var(--color-accent); width: var(--spacing-sm); }
#settings .modification-indicator.changed.unsaved { background-image: linear-gradient(var(--color-accent) 25%, var(--color-accent-soft) 75%); width: var(--spacing-sm); }
#settings_result { margin: 0 1.2em; }
+#tab_settings .gradio-slider, #tab_settings .gradio-dropdown { width: 300px !important; max-width: 300px; }
+#tab_settings textarea { max-width: 500px; }
.licenses { display: block !important; }
/* live preview */
.progressDiv { position: relative; height: 20px; background: #b4c0cc; margin-bottom: -3px; }
.dark .progressDiv { background: #424c5b; }
.progressDiv .progress { width: 0%; height: 20px; background: #0060df; color: white; font-weight: bold; line-height: 20px; padding: 0 8px 0 0; text-align: right; overflow: visible; white-space: nowrap; padding: 0 0.5em; }
-.livePreview { position: absolute; z-index: 50; background-color: transparent; width: -moz-available; width: -webkit-fill-available; }
-.livePreview img { position: absolute; object-fit: contain; width: 100%; height: 100%; }
-.dark .livePreview { background-color: rgb(17 24 39 / var(--tw-bg-opacity)); }
+.livePreview { position: absolute; z-index: 50; width: -moz-available; width: -webkit-fill-available; height: 100%; background-color: var(--background-color); }
+.livePreview img { object-fit: contain; width: 100%; justify-self: center; }
.popup-metadata { color: white; background: #0000; display: inline-block; white-space: pre-wrap; font-size: var(--text-xxs); }
-
+.generating { animation: unset !important; border: unset !important; }
/* fullpage image viewer */
#lightboxModal { display: none; position: fixed; z-index: 1001; left: 0; top: 0; width: 100%; height: 100%; overflow: hidden; background-color: rgba(20, 20, 20, 0.75); backdrop-filter: blur(6px);
user-select: none; -webkit-user-select: none; flex-direction: row; font-family: 'NotoSans';}
@@ -207,7 +217,7 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt
.extra_networks_root { width: 0; position: absolute; height: auto; right: 0; top: 13em; z-index: 100; } /* default is sidebar view */
.extra-networks { background: var(--background-color); padding: var(--block-label-padding); }
.extra-networks > div { margin: 0; border-bottom: none !important; gap: 0.3em 0; }
-.extra-networks .second-line { display: flex; width: -moz-available; width: -webkit-fill-available; gap: 0.3em; box-shadow: var(--input-shadow); }
+.extra-networks .second-line { display: flex; width: -moz-available; width: -webkit-fill-available; gap: 0.3em; box-shadow: var(--input-shadow); margin-bottom: 2px; }
.extra-networks .search { flex: 1; height: 4em; }
.extra-networks .description { flex: 3; }
.extra-networks .tab-nav>button { margin-right: 0; height: 24px; padding: 2px 4px 2px 4px; }
@@ -216,7 +226,7 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt
.extra-networks .custom-button { width: 120px; width: 100%; background: none; justify-content: left; text-align: left; padding: 3px 3px 3px 12px; text-indent: -6px; box-shadow: none; line-break: auto; }
.extra-networks .custom-button:hover { background: var(--button-primary-background-fill) }
.extra-networks-tab { padding: 0 !important; }
-.extra-network-subdirs { background: var(--input-background-fill); overflow-x: hidden; overflow-y: auto; min-width: max(15%, 120px); padding-top: 0.5em; margin-top: -4px !important; }
+.extra-network-subdirs { background: var(--input-background-fill); overflow-x: hidden; overflow-y: auto; min-width: max(15%, 120px); padding-top: 0.5em; border-radius: 4px; }
.extra-networks-page { display: flex }
.extra-network-cards { display: flex; flex-wrap: wrap; overflow-y: auto; overflow-x: hidden; align-content: flex-start; width: -moz-available; width: -webkit-fill-available; }
.extra-network-cards .card { height: fit-content; margin: 0 0 0.5em 0.5em; position: relative; scroll-snap-align: start; scroll-margin-top: 0; }
@@ -380,8 +390,6 @@ div:has(>#tab-gallery-folders) { flex-grow: 0 !important; background-color: var(
#img2img_actions_column { display: flex; min-width: fit-content !important; flex-direction: row;justify-content: space-evenly; align-items: center;}
#txt2img_generate_box, #img2img_generate_box, #txt2img_enqueue_wrapper,#img2img_enqueue_wrapper {display: flex;flex-direction: column;height: 4em !important;align-items: stretch;justify-content: space-evenly;}
#img2img_interface, #img2img_results, #img2img_footer p { text-wrap: wrap; min-width: 100% !important; max-width: 100% !important;} /* maintain single column for from image operations on larger mobile devices */
- #img2img_sketch, #img2maskimg, #inpaint_sketch {display: flex; overflow: auto !important; resize: none !important; } /* fix inpaint image display being too large for mobile displays */
- #img2maskimg canvas { width: auto !important; max-height: 100% !important; height: auto !important; }
#txt2img_sampler, #txt2img_batch, #txt2img_seed_group, #txt2img_advanced, #txt2img_second_pass, #img2img_sampling_group, #img2img_resize_group, #img2img_batch_group, #img2img_seed_group, #img2img_denoise_group, #img2img_advanced_group { width: 100% !important; } /* fix from text/image UI elements to prevent them from moving around within the UI */
#img2img_resize_group .gradio-radio>div { display: flex; flex-direction: column; width: unset !important; }
#inpaint_controls div { display:flex;flex-direction: row;}
diff --git a/javascript/ui.js b/javascript/ui.js
index 8808f1c8b..3e3f14390 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -28,7 +28,7 @@ function clip_gallery_urls(gallery) {
const files = gallery.map((v) => v.data);
navigator.clipboard.writeText(JSON.stringify(files)).then(
() => log('clipboard:', files),
- (err) => console.error('clipboard:', files, err),
+ (err) => error(`clipboard: ${files} ${err}`),
);
}
@@ -139,7 +139,7 @@ function switch_to_inpaint(...args) {
return Array.from(arguments);
}
-function switch_to_inpaint_sketch(...args) {
+function switch_to_composite(...args) {
switchToTab('Image');
switch_to_img2img_tab(3);
return Array.from(arguments);
@@ -493,9 +493,9 @@ function previewTheme() {
el.src = `/file=html/${name}.jpg`;
}
})
- .catch((e) => console.error('previewTheme:', e));
+ .catch((e) => error(`previewTheme: ${e}`));
})
- .catch((e) => console.error('previewTheme:', e));
+ .catch((e) => error(`previewTheme: ${e}`));
}
async function browseFolder() {
diff --git a/launch.py b/launch.py
index f944a7e54..e00da58c7 100755
--- a/launch.py
+++ b/launch.py
@@ -55,9 +55,11 @@ def get_custom_args():
if 'PS1' in env:
del env['PS1']
installer.log.trace(f'Environment: {installer.print_dict(env)}')
- else:
- env = [f'{k}={v}' for k, v in os.environ.items() if k.startswith('SD_')]
- installer.log.debug(f'Env flags: {env}')
+ env = [f'{k}={v}' for k, v in os.environ.items() if k.startswith('SD_')]
+ installer.log.debug(f'Env flags: {env}')
+ ldd = os.environ.get('LD_PRELOAD', None)
+ if ldd is not None:
+ installer.log.debug(f'Linker flags: "{ldd}"')
@lru_cache()
diff --git a/models/Reference/Efficient-Large-Model--Sana_1600M_1024px_diffusers.jpg b/models/Reference/Efficient-Large-Model--Sana_1600M_1024px_diffusers.jpg
new file mode 100644
index 000000000..654f85403
Binary files /dev/null and b/models/Reference/Efficient-Large-Model--Sana_1600M_1024px_diffusers.jpg differ
diff --git a/modules/api/api.py b/modules/api/api.py
index f8346995d..b958085ea 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -35,7 +35,8 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
# server api
self.add_api_route("/sdapi/v1/motd", server.get_motd, methods=["GET"], response_model=str)
- self.add_api_route("/sdapi/v1/log", server.get_log_buffer, methods=["GET"], response_model=List[str])
+ self.add_api_route("/sdapi/v1/log", server.get_log, methods=["GET"], response_model=List[str])
+ self.add_api_route("/sdapi/v1/log", server.post_log, methods=["POST"])
self.add_api_route("/sdapi/v1/start", self.get_session_start, methods=["GET"])
self.add_api_route("/sdapi/v1/version", server.get_version, methods=["GET"])
self.add_api_route("/sdapi/v1/status", server.get_status, methods=["GET"], response_model=models.ResStatus)
@@ -90,6 +91,11 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
self.add_api_route("/sdapi/v1/history", endpoints.get_history, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/history", endpoints.post_history, methods=["POST"], response_model=int)
+ # lora api
+ if shared.native:
+ self.add_api_route("/sdapi/v1/loras", endpoints.get_loras, methods=["GET"], response_model=List[dict])
+ self.add_api_route("/sdapi/v1/refresh-loras", endpoints.post_refresh_loras, methods=["POST"])
+
# gallery api
gallery.register_api(app)
diff --git a/modules/api/control.py b/modules/api/control.py
index 29c5a77f1..411f71ff8 100644
--- a/modules/api/control.py
+++ b/modules/api/control.py
@@ -159,6 +159,8 @@ def post_control(self, req: ReqControl):
output_images = []
output_processed = []
output_info = ''
+ # TODO control: support scripts via api
+ # init script args, call scripts.script_control.run, call scripts.script_control.after
run.control_set({ 'do_not_save_grid': not req.save_images, 'do_not_save_samples': not req.save_images, **self.prepare_ip_adapter(req) })
run.control_set(getattr(req, "extra", {}))
res = run.control_run(**args)
diff --git a/modules/api/endpoints.py b/modules/api/endpoints.py
index 61993db84..1c56b7171 100644
--- a/modules/api/endpoints.py
+++ b/modules/api/endpoints.py
@@ -40,6 +40,12 @@ def convert_embeddings(embeddings):
return {"loaded": convert_embeddings(db.word_embeddings), "skipped": convert_embeddings(db.skipped_embeddings)}
+def get_loras():
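+ # returns available networks as a list of {name, alias, path, metadata} dicts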
+ from modules.lora import network, networks
+ def create_lora_json(obj: network.NetworkOnDisk):
+ return { "name": obj.name, "alias": obj.alias, "path": obj.filename, "metadata": obj.metadata }
+ return [create_lora_json(obj) for obj in networks.available_networks.values()]
+
def get_extra_networks(page: Optional[str] = None, name: Optional[str] = None, filename: Optional[str] = None, title: Optional[str] = None, fullname: Optional[str] = None, hash: Optional[str] = None): # pylint: disable=redefined-builtin
res = []
for pg in shared.extra_networks:
@@ -126,6 +132,10 @@ def post_refresh_checkpoints():
def post_refresh_vae():
return shared.refresh_vaes()
+def post_refresh_loras():
+ from modules.lora import networks
+ return networks.list_available_networks()
+
def get_extensions_list():
from modules import extensions
extensions.list_extensions()
diff --git a/modules/api/generate.py b/modules/api/generate.py
index b8ee645a4..9b409a14b 100644
--- a/modules/api/generate.py
+++ b/modules/api/generate.py
@@ -116,6 +116,8 @@ def post_text2img(self, txt2imgreq: models.ReqTxt2Img):
processed = scripts.scripts_txt2img.run(p, *script_args) # Need to pass args as list here
else:
processed = process_images(p)
+ processed = scripts.scripts_txt2img.after(p, processed, *script_args)
+ p.close()
shared.state.end(api=False)
if processed is None or processed.images is None or len(processed.images) == 0:
b64images = []
@@ -166,6 +168,8 @@ def post_img2img(self, img2imgreq: models.ReqImg2Img):
processed = scripts.scripts_img2img.run(p, *script_args) # Need to pass args as list here
else:
processed = process_images(p)
+ processed = scripts.scripts_img2img.after(p, processed, *script_args)
+ p.close()
shared.state.end(api=False)
if processed is None or processed.images is None or len(processed.images) == 0:
b64images = []
diff --git a/modules/api/models.py b/modules/api/models.py
index e68ebf081..39bcbe383 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -286,10 +286,16 @@ class ResImageInfo(BaseModel):
items: dict = Field(title="Items", description="A dictionary containing all the other fields the image had")
parameters: dict = Field(title="Parameters", description="A dictionary with parsed generation info fields")
-class ReqLog(BaseModel):
+class ReqGetLog(BaseModel):
lines: int = Field(default=100, title="Lines", description="How many lines to return")
clear: bool = Field(default=False, title="Clear", description="Should the log be cleared after returning the lines?")
+
+class ReqPostLog(BaseModel):
+ message: Optional[str] = Field(title="Message", description="The info message to log")
+ debug: Optional[str] = Field(title="Debug message", description="The debug message to log")
+ error: Optional[str] = Field(title="Error message", description="The error message to log")
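+ # all fields are optional; each field that is set is written to the server log at the matching level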
+
class ReqProgress(BaseModel):
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
diff --git a/modules/api/server.py b/modules/api/server.py
index 939e19c86..dabbe634c 100644
--- a/modules/api/server.py
+++ b/modules/api/server.py
@@ -37,12 +37,22 @@ def get_platform():
from modules.loader import get_packages as loader_get_packages
return { **installer_get_platform(), **loader_get_packages() }
-def get_log_buffer(req: models.ReqLog = Depends()):
+def get_log(req: models.ReqGetLog = Depends()):
lines = shared.log.buffer[:req.lines] if req.lines > 0 else shared.log.buffer.copy()
if req.clear:
shared.log.buffer.clear()
return lines
+def post_log(req: models.ReqPostLog):
+ if req.message is not None:
+ shared.log.info(f'UI: {req.message}')
+ if req.debug is not None:
+ shared.log.debug(f'UI: {req.debug}')
+ if req.error is not None:
+ shared.log.error(f'UI: {req.error}')
+ return {}
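+ # example (assumption: default local server): POST /sdapi/v1/log with body {"message": "...", "debug": "...", "error": "..."}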
+
+
def get_config():
options = {}
for k in shared.opts.data.keys():
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 4065d13d9..11ba7b56e 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -2,7 +2,7 @@
import threading
import time
import cProfile
-from modules import shared, progress, errors
+from modules import shared, progress, errors, timer
queue_lock = threading.Lock()
@@ -73,15 +73,20 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs):
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60
elapsed_text = f"{elapsed_m}m {elapsed_s:.2f}s" if elapsed_m > 0 else f"{elapsed_s:.2f}s"
- vram_html = ''
+ summary = timer.process.summary(min_time=0.25, total=False).replace('=', ' ')
+ gpu = ''
+ cpu = ''
if not shared.mem_mon.disabled:
vram = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.read().items()}
- if vram.get('active_peak', 0) > 0:
- vram_html = " | "
- vram_html += f"GPU active {max(vram['active_peak'], vram['reserved_peak'])} MB reserved {vram['reserved']} | used {vram['used']} MB free {vram['free']} MB total {vram['total']} MB"
- vram_html += f" | retries {vram['retries']} oom {vram['oom']}" if vram.get('retries', 0) > 0 or vram.get('oom', 0) > 0 else ''
- vram_html += "
"
+ peak = max(vram['active_peak'], vram['reserved_peak'], vram['used'])
+ used = round(100.0 * peak / vram['total']) if vram['total'] > 0 else 0
+ if used > 0:
+ gpu += f"| GPU {peak} MB {used}%"
+ gpu += f" | retries {vram['retries']} oom {vram['oom']}" if vram.get('retries', 0) > 0 or vram.get('oom', 0) > 0 else ''
+ ram = shared.ram_stats()
+ if ram['used'] > 0:
+ cpu += f"| RAM {ram['used']} GB {round(100.0 * ram['used'] / ram['total'])}%"
if isinstance(res, list):
- res[-1] += f""
+ res[-1] += f""
return tuple(res)
return f
diff --git a/modules/consistory/consistory_unet_sdxl.py b/modules/consistory/consistory_unet_sdxl.py
index 4dd9b42d2..940b4ba01 100644
--- a/modules/consistory/consistory_unet_sdxl.py
+++ b/modules/consistory/consistory_unet_sdxl.py
@@ -916,7 +916,6 @@ def forward(
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
diff --git a/modules/control/run.py b/modules/control/run.py
index 5d6343c98..8cecb93af 100644
--- a/modules/control/run.py
+++ b/modules/control/run.py
@@ -7,6 +7,7 @@
from modules.control import util # helper functions
from modules.control import unit # control units
from modules.control import processors # image preprocessors
+from modules.control import tile # tiling module
from modules.control.units import controlnet # lllyasviel ControlNet
from modules.control.units import xs # VisLearn ControlNet-XS
from modules.control.units import lite # Kohya ControlLLLite
@@ -44,6 +45,167 @@ def terminate(msg):
return msg
+def set_pipe(p, has_models, unit_type, selected_models, active_model, active_strength, control_conditioning, control_guidance_start, control_guidance_end, inits):
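+ # build the control pipeline for the active unit type and record control parameters into extra_generation_params; falls back to the base txt2img/img2img pipeline when no control models are active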
+ global pipe, instance # pylint: disable=global-statement
+ pipe = None
+ if has_models:
+ p.ops.append('control')
+ p.extra_generation_params["Control type"] = unit_type # overriden later with pretty-print
+ p.extra_generation_params["Control model"] = ';'.join([(m.model_id or '') for m in active_model if m.model is not None])
+ p.extra_generation_params["Control conditioning"] = control_conditioning if isinstance(control_conditioning, list) else [control_conditioning]
+ p.extra_generation_params['Control start'] = control_guidance_start if isinstance(control_guidance_start, list) else [control_guidance_start]
+ p.extra_generation_params['Control end'] = control_guidance_end if isinstance(control_guidance_end, list) else [control_guidance_end]
+ p.extra_generation_params["Control conditioning"] = ';'.join([str(c) for c in p.extra_generation_params["Control conditioning"]])
+ p.extra_generation_params['Control start'] = ';'.join([str(c) for c in p.extra_generation_params['Control start']])
+ p.extra_generation_params['Control end'] = ';'.join([str(c) for c in p.extra_generation_params['Control end']])
+ if unit_type == 't2i adapter' and has_models:
+ p.extra_generation_params["Control type"] = 'T2I-Adapter'
+ p.task_args['adapter_conditioning_scale'] = control_conditioning
+ instance = t2iadapter.AdapterPipeline(selected_models, shared.sd_model)
+ pipe = instance.pipeline
+ if inits is not None:
+ shared.log.warning('Control: T2I-Adapter does not support separate init image')
+ elif unit_type == 'controlnet' and has_models:
+ p.extra_generation_params["Control type"] = 'ControlNet'
+ p.task_args['controlnet_conditioning_scale'] = control_conditioning
+ p.task_args['control_guidance_start'] = control_guidance_start
+ p.task_args['control_guidance_end'] = control_guidance_end
+ p.task_args['guess_mode'] = p.guess_mode
+ instance = controlnet.ControlNetPipeline(selected_models, shared.sd_model, p=p)
+ pipe = instance.pipeline
+ elif unit_type == 'xs' and has_models:
+ p.extra_generation_params["Control type"] = 'ControlNet-XS'
+ p.controlnet_conditioning_scale = control_conditioning
+ p.control_guidance_start = control_guidance_start
+ p.control_guidance_end = control_guidance_end
+ instance = xs.ControlNetXSPipeline(selected_models, shared.sd_model)
+ pipe = instance.pipeline
+ if inits is not None:
+ shared.log.warning('Control: ControlNet-XS does not support separate init image')
+ elif unit_type == 'lite' and has_models:
+ p.extra_generation_params["Control type"] = 'ControlLLLite'
+ p.controlnet_conditioning_scale = control_conditioning
+ instance = lite.ControlLLitePipeline(shared.sd_model)
+ pipe = instance.pipeline
+ if inits is not None:
+ shared.log.warning('Control: ControlLLLite does not support separate init image')
+ elif unit_type == 'reference' and has_models:
+ p.extra_generation_params["Control type"] = 'Reference'
+ p.extra_generation_params["Control attention"] = p.attention
+ p.task_args['reference_attn'] = 'Attention' in p.attention
+ p.task_args['reference_adain'] = 'Adain' in p.attention
+ p.task_args['attention_auto_machine_weight'] = p.query_weight
+ p.task_args['gn_auto_machine_weight'] = p.adain_weight
+ p.task_args['style_fidelity'] = p.fidelity
+ instance = reference.ReferencePipeline(shared.sd_model)
+ pipe = instance.pipeline
+ if inits is not None:
+ shared.log.warning('Control: Reference mode does not support separate init image')
+ else: # run in txt2img/img2img mode
+ if len(active_strength) > 0:
+ p.strength = active_strength[0]
+ pipe = shared.sd_model
+ instance = None
+ debug(f'Control: run type={unit_type} models={has_models} pipe={pipe.__class__.__name__ if pipe is not None else None}')
+ return pipe
+
+
+def check_active(p, unit_type, units):
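+ # collect processors, models and strength/start/end values from enabled units of the requested type; models of disabled units are moved back to cpu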
+ active_process: List[processors.Processor] = [] # all active preprocessors
+ active_model: List[Union[controlnet.ControlNet, xs.ControlNetXS, t2iadapter.Adapter]] = [] # all active models
+ active_strength: List[float] = [] # strength factors for all active models
+ active_start: List[float] = [] # start step for all active models
+ active_end: List[float] = [] # end step for all active models
+ num_units = 0
+ for u in units:
+ if u.type != unit_type:
+ continue
+ num_units += 1
+ debug(f'Control unit: i={num_units} type={u.type} enabled={u.enabled}')
+ if not u.enabled:
+ if u.controlnet is not None and u.controlnet.model is not None:
+ debug(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.cpu}')
+ sd_models.move_model(u.controlnet.model, devices.cpu)
+ continue
+ if u.controlnet is not None and u.controlnet.model is not None:
+ debug(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.device}')
+ sd_models.move_model(u.controlnet.model, devices.device)
+ if unit_type == 't2i adapter' and u.adapter.model is not None:
+ active_process.append(u.process)
+ active_model.append(u.adapter)
+ active_strength.append(float(u.strength))
+ p.adapter_conditioning_factor = u.factor
+ shared.log.debug(f'Control T2I-Adapter unit: i={num_units} process="{u.process.processor_id}" model="{u.adapter.model_id}" strength={u.strength} factor={u.factor}')
+ elif unit_type == 'controlnet' and u.controlnet.model is not None:
+ active_process.append(u.process)
+ active_model.append(u.controlnet)
+ active_strength.append(float(u.strength))
+ active_start.append(float(u.start))
+ active_end.append(float(u.end))
+ p.guess_mode = u.guess
+ if isinstance(u.mode, str):
+ p.control_mode = u.choices.index(u.mode) if u.mode in u.choices else 0
+ p.is_tile = p.is_tile or 'tile' in u.mode.lower()
+ p.control_tile = u.tile
+ p.extra_generation_params["Control mode"] = u.mode
+ shared.log.debug(f'Control ControlNet unit: i={num_units} process="{u.process.processor_id}" model="{u.controlnet.model_id}" strength={u.strength} guess={u.guess} start={u.start} end={u.end} mode={u.mode}')
+ elif unit_type == 'xs' and u.controlnet.model is not None:
+ active_process.append(u.process)
+ active_model.append(u.controlnet)
+ active_strength.append(float(u.strength))
+ active_start.append(float(u.start))
+ active_end.append(float(u.end))
+ shared.log.debug(f'Control ControlNet-XS unit: i={num_units} process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
+ elif unit_type == 'lite' and u.controlnet.model is not None:
+ active_process.append(u.process)
+ active_model.append(u.controlnet)
+ active_strength.append(float(u.strength))
+ shared.log.debug(f'Control ControlLLite unit: i={num_units} process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
+ elif unit_type == 'reference':
+ p.override = u.override
+ p.attention = u.attention
+ p.query_weight = float(u.query_weight)
+ p.adain_weight = float(u.adain_weight)
+ p.fidelity = u.fidelity
+ shared.log.debug('Control Reference unit')
+ else:
+ if u.process.processor_id is not None:
+ active_process.append(u.process)
+ shared.log.debug(f'Control process unit: i={num_units} process={u.process.processor_id}')
+ active_strength.append(float(u.strength))
+ debug(f'Control active: process={len(active_process)} model={len(active_model)}')
+ return active_process, active_model, active_strength, active_start, active_end
+
+
+def check_enabled(p, unit_type, units, active_model, active_strength, active_start, active_end):
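+ # resolve the selected models and normalize conditioning/start/end to a scalar for a single model or to lists for multiple models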
+ has_models = False
+ selected_models: List[Union[controlnet.ControlNetModel, xs.ControlNetXSModel, t2iadapter.AdapterModel]] = None
+ control_conditioning = None
+ control_guidance_start = None
+ control_guidance_end = None
+ if unit_type == 't2i adapter' or unit_type == 'controlnet' or unit_type == 'xs' or unit_type == 'lite':
+ if len(active_model) == 0:
+ selected_models = None
+ elif len(active_model) == 1:
+ selected_models = active_model[0].model if active_model[0].model is not None else None
+ p.is_tile = p.is_tile or 'tile' in active_model[0].model_id.lower()
+ has_models = selected_models is not None
+ control_conditioning = active_strength[0] if len(active_strength) > 0 else 1 # strength or list[strength]
+ control_guidance_start = active_start[0] if len(active_start) > 0 else 0
+ control_guidance_end = active_end[0] if len(active_end) > 0 else 1
+ else:
+ selected_models = [m.model for m in active_model if m.model is not None]
+ has_models = len(selected_models) > 0
+ control_conditioning = active_strength[0] if len(active_strength) == 1 else list(active_strength) # strength or list[strength]
+ control_guidance_start = active_start[0] if len(active_start) == 1 else list(active_start)
+ control_guidance_end = active_end[0] if len(active_end) == 1 else list(active_end)
+ elif unit_type == 'reference':
+ has_models = any(u.enabled for u in units if u.type == 'reference')
+ else:
+ pass
+ return has_models, selected_models, control_conditioning, control_guidance_start, control_guidance_end
+
+
def control_set(kwargs):
if kwargs:
global p_extra_args # pylint: disable=global-statement
@@ -83,20 +245,15 @@ def control_run(state: str = '',
u.adapter.load(u.model_name, force=False)
else:
u.controlnet.load(u.model_name, force=False)
+ u.update_choices(u.model_name)
if u.process is not None and u.process.override is None and u.override is not None:
u.process.override = u.override
- global instance, pipe, original_pipeline # pylint: disable=global-statement
- t_start = time.time()
+ global pipe, original_pipeline # pylint: disable=global-statement
debug(f'Control: type={unit_type} input={inputs} init={inits} type={input_type}')
if inputs is None or (type(inputs) is list and len(inputs) == 0):
inputs = [None]
output_images: List[Image.Image] = [] # output images
- active_process: List[processors.Processor] = [] # all active preprocessors
- active_model: List[Union[controlnet.ControlNet, xs.ControlNetXS, t2iadapter.Adapter]] = [] # all active models
- active_strength: List[float] = [] # strength factors for all active models
- active_start: List[float] = [] # start step for all active models
- active_end: List[float] = [] # end step for all active models
processed_image: Image.Image = None # last processed image
if mask is not None and input_type == 0:
input_type = 1 # inpaint always requires control_image
@@ -150,10 +307,11 @@ def control_run(state: str = '',
outpath_grids=shared.opts.outdir_grids or shared.opts.outdir_control_grids,
)
p.state = state
+ p.is_tile = False
# processing.process_init(p)
resize_mode_before = resize_mode_before if resize_name_before != 'None' and inputs is not None and len(inputs) > 0 else 0
- # TODO monkey-patch for modernui missing tabs.select event
+ # TODO modernui: monkey-patch for missing tabs.select event
if selected_scale_tab_before == 0 and resize_name_before != 'None' and scale_by_before != 1 and inputs is not None and len(inputs) > 0:
shared.log.debug('Control: override resize mode=before')
selected_scale_tab_before = 1
@@ -224,155 +382,17 @@ def control_run(state: str = '',
unit_type = unit_type.strip().lower() if unit_type is not None else ''
t0 = time.time()
- num_units = 0
- for u in units:
- if u.type != unit_type:
- continue
- num_units += 1
- debug(f'Control unit: i={num_units} type={u.type} enabled={u.enabled}')
- if not u.enabled:
- if u.controlnet is not None and u.controlnet.model is not None:
- debug(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.cpu}')
- sd_models.move_model(u.controlnet.model, devices.cpu)
- continue
- if u.controlnet is not None and u.controlnet.model is not None:
- debug(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.device}')
- sd_models.move_model(u.controlnet.model, devices.device)
- if unit_type == 't2i adapter' and u.adapter.model is not None:
- active_process.append(u.process)
- active_model.append(u.adapter)
- active_strength.append(float(u.strength))
- p.adapter_conditioning_factor = u.factor
- shared.log.debug(f'Control T2I-Adapter unit: i={num_units} process={u.process.processor_id} model={u.adapter.model_id} strength={u.strength} factor={u.factor}')
- elif unit_type == 'controlnet' and u.controlnet.model is not None:
- active_process.append(u.process)
- active_model.append(u.controlnet)
- active_strength.append(float(u.strength))
- active_start.append(float(u.start))
- active_end.append(float(u.end))
- p.guess_mode = u.guess
- p.control_mode = u.mode
- shared.log.debug(f'Control ControlNet unit: i={num_units} process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end} mode={u.mode}')
- elif unit_type == 'xs' and u.controlnet.model is not None:
- active_process.append(u.process)
- active_model.append(u.controlnet)
- active_strength.append(float(u.strength))
- active_start.append(float(u.start))
- active_end.append(float(u.end))
- shared.log.debug(f'Control ControlNet-XS unit: i={num_units} process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
- elif unit_type == 'lite' and u.controlnet.model is not None:
- active_process.append(u.process)
- active_model.append(u.controlnet)
- active_strength.append(float(u.strength))
- shared.log.debug(f'Control ControlLLite unit: i={num_units} process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
- elif unit_type == 'reference':
- p.override = u.override
- p.attention = u.attention
- p.query_weight = float(u.query_weight)
- p.adain_weight = float(u.adain_weight)
- p.fidelity = u.fidelity
- shared.log.debug('Control Reference unit')
- else:
- if u.process.processor_id is not None:
- active_process.append(u.process)
- shared.log.debug(f'Control process unit: i={num_units} process={u.process.processor_id}')
- active_strength.append(float(u.strength))
- debug(f'Control active: process={len(active_process)} model={len(active_model)}')
+
+ active_process, active_model, active_strength, active_start, active_end = check_active(p, unit_type, units)
+ has_models, selected_models, control_conditioning, control_guidance_start, control_guidance_end = check_enabled(p, unit_type, units, active_model, active_strength, active_start, active_end)
processed: processing.Processed = None
image_txt = ''
info_txt = []
- has_models = False
- selected_models: List[Union[controlnet.ControlNetModel, xs.ControlNetXSModel, t2iadapter.AdapterModel]] = None
- control_conditioning = None
- control_guidance_start = None
- control_guidance_end = None
- if unit_type == 't2i adapter' or unit_type == 'controlnet' or unit_type == 'xs' or unit_type == 'lite':
- if len(active_model) == 0:
- selected_models = None
- elif len(active_model) == 1:
- selected_models = active_model[0].model if active_model[0].model is not None else None
- has_models = selected_models is not None
- control_conditioning = active_strength[0] if len(active_strength) > 0 else 1 # strength or list[strength]
- control_guidance_start = active_start[0] if len(active_start) > 0 else 0
- control_guidance_end = active_end[0] if len(active_end) > 0 else 1
- else:
- selected_models = [m.model for m in active_model if m.model is not None]
- has_models = len(selected_models) > 0
- control_conditioning = active_strength[0] if len(active_strength) == 1 else list(active_strength) # strength or list[strength]
- control_guidance_start = active_start[0] if len(active_start) == 1 else list(active_start)
- control_guidance_end = active_end[0] if len(active_end) == 1 else list(active_end)
- elif unit_type == 'reference':
- has_models = any(u.enabled for u in units if u.type == 'reference')
- else:
- pass
- def set_pipe():
- global pipe, instance # pylint: disable=global-statement
- pipe = None
- if has_models:
- p.ops.append('control')
- p.extra_generation_params["Control mode"] = unit_type # overriden later with pretty-print
- p.extra_generation_params["Control conditioning"] = control_conditioning if isinstance(control_conditioning, list) else [control_conditioning]
- p.extra_generation_params['Control start'] = control_guidance_start if isinstance(control_guidance_start, list) else [control_guidance_start]
- p.extra_generation_params['Control end'] = control_guidance_end if isinstance(control_guidance_end, list) else [control_guidance_end]
- p.extra_generation_params["Control model"] = ';'.join([(m.model_id or '') for m in active_model if m.model is not None])
- p.extra_generation_params["Control conditioning"] = ';'.join([str(c) for c in p.extra_generation_params["Control conditioning"]])
- p.extra_generation_params['Control start'] = ';'.join([str(c) for c in p.extra_generation_params['Control start']])
- p.extra_generation_params['Control end'] = ';'.join([str(c) for c in p.extra_generation_params['Control end']])
- if unit_type == 't2i adapter' and has_models:
- p.extra_generation_params["Control mode"] = 'T2I-Adapter'
- p.task_args['adapter_conditioning_scale'] = control_conditioning
- instance = t2iadapter.AdapterPipeline(selected_models, shared.sd_model)
- pipe = instance.pipeline
- if inits is not None:
- shared.log.warning('Control: T2I-Adapter does not support separate init image')
- elif unit_type == 'controlnet' and has_models:
- p.extra_generation_params["Control mode"] = 'ControlNet'
- p.task_args['controlnet_conditioning_scale'] = control_conditioning
- p.task_args['control_guidance_start'] = control_guidance_start
- p.task_args['control_guidance_end'] = control_guidance_end
- p.task_args['guess_mode'] = p.guess_mode
- instance = controlnet.ControlNetPipeline(selected_models, shared.sd_model)
- pipe = instance.pipeline
- elif unit_type == 'xs' and has_models:
- p.extra_generation_params["Control mode"] = 'ControlNet-XS'
- p.controlnet_conditioning_scale = control_conditioning
- p.control_guidance_start = control_guidance_start
- p.control_guidance_end = control_guidance_end
- instance = xs.ControlNetXSPipeline(selected_models, shared.sd_model)
- pipe = instance.pipeline
- if inits is not None:
- shared.log.warning('Control: ControlNet-XS does not support separate init image')
- elif unit_type == 'lite' and has_models:
- p.extra_generation_params["Control mode"] = 'ControlLLLite'
- p.controlnet_conditioning_scale = control_conditioning
- instance = lite.ControlLLitePipeline(shared.sd_model)
- pipe = instance.pipeline
- if inits is not None:
- shared.log.warning('Control: ControlLLLite does not support separate init image')
- elif unit_type == 'reference' and has_models:
- p.extra_generation_params["Control mode"] = 'Reference'
- p.extra_generation_params["Control attention"] = p.attention
- p.task_args['reference_attn'] = 'Attention' in p.attention
- p.task_args['reference_adain'] = 'Adain' in p.attention
- p.task_args['attention_auto_machine_weight'] = p.query_weight
- p.task_args['gn_auto_machine_weight'] = p.adain_weight
- p.task_args['style_fidelity'] = p.fidelity
- instance = reference.ReferencePipeline(shared.sd_model)
- pipe = instance.pipeline
- if inits is not None:
- shared.log.warning('Control: ControlNet-XS does not support separate init image')
- else: # run in txt2img/img2img mode
- if len(active_strength) > 0:
- p.strength = active_strength[0]
- pipe = shared.sd_model
- instance = None
- debug(f'Control: run type={unit_type} models={has_models} pipe={pipe.__class__.__name__ if pipe is not None else None}')
- return pipe
-
-
- pipe = set_pipe()
+ p.is_tile = p.is_tile and has_models
+
+ pipe = set_pipe(p, has_models, unit_type, selected_models, active_model, active_strength, control_conditioning, control_guidance_start, control_guidance_end, inits)
debug(f'Control pipeline: class={pipe.__class__.__name__} args={vars(p)}')
t1, t2, t3 = time.time(), 0, 0
status = True
@@ -395,6 +415,8 @@ def set_pipe():
else:
original_pipeline = None
+ possible = sd_models.get_call(pipe).keys()
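+ # names of arguments accepted by the current pipeline call; used below to only set task_args the pipeline understands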
+
try:
with devices.inference_context():
if isinstance(inputs, str): # only video, the rest is a list
@@ -424,7 +446,7 @@ def set_pipe():
while status:
if pipe is None: # pipe may have been reset externally
- pipe = set_pipe()
+ pipe = set_pipe(p, has_models, unit_type, selected_models, active_model, active_strength, control_conditioning, control_guidance_start, control_guidance_end, inits)
debug(f'Control pipeline reinit: class={pipe.__class__.__name__}')
processed_image = None
if frame is not None:
@@ -564,19 +586,29 @@ def set_pipe():
return [], '', '', 'Reference mode without image'
elif unit_type == 'controlnet' and has_models:
if input_type == 0: # Control only
- if shared.sd_model_type in ['f1', 'sd3'] and 'control_image' not in p.task_args:
- p.task_args['control_image'] = p.init_images # some controlnets mandate this
+ if 'control_image' in possible:
+ p.task_args['control_image'] = [p.init_images] if isinstance(p.init_images, Image.Image) else p.init_images
+ elif 'image' in possible:
+ p.task_args['image'] = [p.init_images] if isinstance(p.init_images, Image.Image) else p.init_images
+ if 'control_mode' in possible:
+ p.task_args['control_mode'] = getattr(p, 'control_mode', None)
+ if 'strength' in possible:
p.task_args['strength'] = p.denoising_strength
+ p.init_images = None
elif input_type == 1: # Init image same as control
- p.task_args['control_image'] = p.init_images # switch image and control_image
- p.task_args['strength'] = p.denoising_strength
+ if 'control_image' in possible:
+ p.task_args['control_image'] = p.init_images # switch image and control_image
+ if 'strength' in possible:
+ p.task_args['strength'] = p.denoising_strength
p.init_images = [p.override or input_image] * len(active_model)
elif input_type == 2: # Separate init image
if init_image is None:
shared.log.warning('Control: separate init image not provided')
init_image = input_image
- p.task_args['control_image'] = p.init_images # switch image and control_image
- p.task_args['strength'] = p.denoising_strength
+ if 'control_image' in possible:
+ p.task_args['control_image'] = p.init_images # switch image and control_image
+ if 'strength' in possible:
+ p.task_args['strength'] = p.denoising_strength
p.init_images = [init_image] * len(active_model)
if is_generator:
@@ -609,26 +641,31 @@ def set_pipe():
p.task_args['strength'] = denoising_strength
p.image_mask = mask
shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.INPAINTING) # only controlnet supports inpaint
- elif 'control_image' in p.task_args:
+ if hasattr(p, 'init_images') and p.init_images is not None:
shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE) # only controlnet supports img2img
else:
shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.TEXT_2_IMAGE)
- if hasattr(p, 'init_images') and p.init_images is not None:
+ if hasattr(p, 'init_images') and p.init_images is not None and 'image' in possible:
p.task_args['image'] = p.init_images # need to set explicitly for txt2img
del p.init_images
if unit_type == 'lite':
p.init_image = [input_image]
instance.apply(selected_models, processed_image, control_conditioning)
- if p.control_mode is not None:
- p.task_args['control_mode'] = p.control_mode
+ if getattr(p, 'control_mode', None) is not None:
+ p.task_args['control_mode'] = getattr(p, 'control_mode', None)
if hasattr(p, 'init_images') and p.init_images is None: # delete empty
del p.init_images
# final check
if has_models:
- if unit_type in ['controlnet', 't2i adapter', 'lite', 'xs'] and p.task_args.get('image', None) is None and getattr(p, 'init_images', None) is None:
+ if unit_type in ['controlnet', 't2i adapter', 'lite', 'xs'] \
+ and p.task_args.get('image', None) is None \
+ and p.task_args.get('control_image', None) is None \
+ and getattr(p, 'init_images', None) is None \
+ and getattr(p, 'image', None) is None:
if is_generator:
- yield terminate(f'Mode={p.extra_generation_params.get("Control mode", None)} input image is none')
+ shared.log.debug(f'Control args: {p.task_args}')
+ yield terminate(f'Mode={p.extra_generation_params.get("Control type", None)} input image is none')
return [], '', '', 'Error: Input image is none'
# resize mask
@@ -658,12 +695,19 @@ def set_pipe():
script_runner.initialize_scripts(False)
p.script_args = script.init_default_script_args(script_runner)
- processed = p.scripts.run(p, *p.script_args)
+ # actual processing
+ if p.is_tile:
+ processed: processing.Processed = tile.run_tiling(p, input_image)
+ if processed is None and p.scripts is not None:
+ processed = p.scripts.run(p, *p.script_args)
if processed is None:
processed: processing.Processed = processing.process_images(p) # run actual pipeline
else:
script_run = True
- processed = p.scripts.after(p, processed, *p.script_args)
+
+ # postprocessing
+ if p.scripts is not None:
+ processed = p.scripts.after(p, processed, *p.script_args)
output = None
if processed is not None:
output = processed.images
@@ -717,14 +761,11 @@ def set_pipe():
shared.log.error(f'Control pipeline failed: type={unit_type} units={len(active_model)} error={e}')
errors.display(e, 'Control')
- t_end = time.time()
-
if len(output_images) == 0:
output_images = None
image_txt = '| Images None'
else:
- image_str = [f'{image.width}x{image.height}' for image in output_images]
- image_txt = f'| Time {t_end-t_start:.2f}s | Images {len(output_images)} | Size {" ".join(image_str)}'
+ image_txt = ''
p.init_images = output_images # may be used for hires
if video_type != 'None' and isinstance(output_images, list):
@@ -738,10 +779,9 @@ def set_pipe():
restore_pipeline()
debug(f'Ready: {image_txt}')
- html_txt = f'<p>Ready {image_txt}</p>'
+ html_txt = f'<p>Ready {image_txt}</p>' if image_txt != '' else ''
if len(info_txt) > 0:
html_txt = html_txt + infotext_to_html(info_txt[0])
if is_generator:
yield (output_images, blended_image, html_txt, output_filename)
- else:
- return (output_images, blended_image, html_txt, output_filename)
+ return (output_images, blended_image, html_txt, output_filename)
diff --git a/modules/control/tile.py b/modules/control/tile.py
new file mode 100644
index 000000000..de9df1131
--- /dev/null
+++ b/modules/control/tile.py
@@ -0,0 +1,73 @@
+import time
+from PIL import Image
+from modules import shared, processing, images, sd_models
+
+
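+# note: tile boxes use integer division, so individual tiles may differ in size by one pixel when the image size is not divisible by sx/sy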
+def get_tile(image: Image.Image, x: int, y: int, sx: int, sy: int) -> Image.Image:
+ return image.crop((
+ (x + 0) * image.width // sx,
+ (y + 0) * image.height // sy,
+ (x + 1) * image.width // sx,
+ (y + 1) * image.height // sy
+ ))
+
+
+def set_tile(image: Image.Image, x: int, y: int, tiled: Image.Image):
+ image.paste(tiled, (x * tiled.width, y * tiled.height))
+ return image
+
+
+def run_tiling(p: processing.StableDiffusionProcessing, input_image: Image.Image) -> processing.Processed:
+ t0 = time.time()
+ # prepare images
+ sx, sy = p.control_tile.split('x')
+ sx = int(sx)
+ sy = int(sy)
+ if sx <= 0 or sy <= 0:
+ raise ValueError('Control Tile: invalid tile size')
+ control_image = p.task_args.get('control_image', None) or p.task_args.get('image', None)
+ control_upscaled = None
+ if isinstance(control_image, list) and len(control_image) > 0:
+ w, h = 8 * int(sx * control_image[0].width) // 8, 8 * int(sy * control_image[0].height) // 8
+ control_upscaled = images.resize_image(resize_mode=1 if sx==sy else 5, im=control_image[0], width=w, height=h, context='add with forward')
+ init_image = p.override or input_image
+ init_upscaled = None
+ if init_image is not None:
+ w, h = 8 * int(sx * init_image.width) // 8, 8 * int(sy * init_image.height) // 8
+ init_upscaled = images.resize_image(resize_mode=1 if sx==sy else 5, im=init_image, width=w, height=h, context='add with forward')
+ t1 = time.time()
+ shared.log.debug(f'Control Tile: scale={sx}x{sy} resize={"fixed" if sx==sy else "context"} control={control_upscaled} init={init_upscaled} time={t1-t0:.3f}')
+
+ # stop processing from restoring pipeline on each iteration
+ orig_restore_pipeline = getattr(shared.sd_model, 'restore_pipeline', None)
+ shared.sd_model.restore_pipeline = None
+
+ # run tiling
+ for x in range(sx):
+ for y in range(sy):
+ shared.log.info(f'Control Tile: tile={x+1}-{sx}/{y+1}-{sy} target={control_upscaled}')
+ shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE)
+ p.init_images = None
+ p.task_args['control_mode'] = p.control_mode
+ p.task_args['strength'] = p.denoising_strength
+ if init_upscaled is not None:
+ p.task_args['image'] = [get_tile(init_upscaled, x, y, sx, sy)]
+ if control_upscaled is not None:
+ p.task_args['control_image'] = [get_tile(control_upscaled, x, y, sx, sy)]
+ processed: processing.Processed = processing.process_images(p) # run actual pipeline
+ if processed is None or len(processed.images) == 0:
+ continue
+ control_upscaled = set_tile(control_upscaled, x, y, processed.images[0])
+
+ # post-process
+ p.width = control_upscaled.width
+ p.height = control_upscaled.height
+ processed.images = [control_upscaled]
+ processed.info = processed.infotext(p, 0)
+ processed.infotexts = [processed.info]
+ shared.sd_model.restore_pipeline = orig_restore_pipeline
+ if hasattr(shared.sd_model, 'restore_pipeline') and shared.sd_model.restore_pipeline is not None:
+ shared.sd_model.restore_pipeline()
+ t2 = time.time()
+ shared.log.debug(f'Control Tile: image={control_upscaled} time={t2-t0:.3f}')
+ return processed
diff --git a/modules/control/unit.py b/modules/control/unit.py
index 7dc5528a6..eeb729740 100644
--- a/modules/control/unit.py
+++ b/modules/control/unit.py
@@ -16,6 +16,22 @@
class Unit(): # mashup of gradio controls and mapping to actual implementation classes
+ def update_choices(self, model_id=None):
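+ # map the selected model to the control modes it supports; the chosen mode's index is later passed to the pipeline as control_mode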
+ name = model_id or self.model_name
+ if name == 'InstantX Union':
+ self.choices = ['canny', 'tile', 'depth', 'blur', 'pose', 'gray', 'lq']
+ elif name == 'Shakker-Labs Union':
+ self.choices = ['canny', 'tile', 'depth', 'blur', 'pose', 'gray', 'lq']
+ elif name == 'Xinsir Union XL':
+ self.choices = ['openpose', 'depth', 'scribble', 'canny', 'normal']
+ elif name == 'Xinsir ProMax XL':
+ self.choices = ['openpose', 'depth', 'scribble', 'canny', 'normal', 'segment', 'tile', 'repaint']
+ else:
+ self.choices = ['default']
+
+ def __str__(self):
+ return f'Unit: type={self.type} enabled={self.enabled} strength={self.strength} start={self.start} end={self.end} mode={self.mode} tile={self.tile}'
+
def __init__(self,
# values
index: int = None,
@@ -38,6 +54,7 @@ def __init__(self,
control_start = None,
control_end = None,
control_mode = None,
+ control_tile = None,
result_txt = None,
extra_controls: list = [],
):
@@ -70,6 +87,10 @@ def __init__(self,
self.fidelity = 0.5
self.query_weight = 1.0
self.adain_weight = 1.0
+ # control mode
+ self.choices = ['default']
+ # control tile
+ self.tile = '1x1'
def reset():
if self.process is not None:
@@ -92,10 +113,16 @@ def control_change(start, end):
self.end = max(start, end)
def control_mode_change(mode):
- self.mode = mode - 1 if mode > 0 else None
+ self.mode = self.choices.index(mode) if mode is not None and mode in self.choices else 0
+
+ def control_tile_change(tile):
+ self.tile = tile
- def control_mode_show(model_id):
- return gr.update(visible='union' in model_id.lower())
+ def control_choices(model_id):
+ self.update_choices(model_id)
+ mode_visible = 'union' in model_id.lower() or 'promax' in model_id.lower()
+ tile_visible = 'union' in model_id.lower() or 'promax' in model_id.lower() or 'tile' in model_id.lower()
+ return [gr.update(visible=mode_visible, choices=self.choices), gr.update(visible=tile_visible)]
def adapter_extra(c1):
self.factor = c1
@@ -172,7 +199,7 @@ def set_image(image):
else:
self.controls.append(model_id)
model_id.change(fn=self.controlnet.load, inputs=[model_id], outputs=[result_txt], show_progress=True)
- model_id.change(fn=control_mode_show, inputs=[model_id], outputs=[control_mode], show_progress=False)
+ model_id.change(fn=control_choices, inputs=[model_id], outputs=[control_mode, control_tile], show_progress=False)
if extra_controls is not None and len(extra_controls) > 0:
extra_controls[0].change(fn=controlnet_extra, inputs=extra_controls)
elif self.type == 'xs':
@@ -231,3 +258,6 @@ def set_image(image):
if control_mode is not None:
self.controls.append(control_mode)
control_mode.change(fn=control_mode_change, inputs=[control_mode])
+ if control_tile is not None:
+ self.controls.append(control_tile)
+ control_tile.change(fn=control_tile_change, inputs=[control_tile])
diff --git a/modules/control/units/controlnet.py b/modules/control/units/controlnet.py
index 20b99412a..c887aca8f 100644
--- a/modules/control/units/controlnet.py
+++ b/modules/control/units/controlnet.py
@@ -5,6 +5,7 @@
from modules.control.units import detect
from modules.shared import log, opts, listdir
from modules import errors, sd_models, devices, model_quant
+from modules.processing import StableDiffusionProcessingControl
what = 'ControlNet'
@@ -51,17 +52,20 @@
'Depth Mid XL': 'diffusers/controlnet-depth-sdxl-1.0-mid',
'OpenPose XL': 'thibaud/controlnet-openpose-sdxl-1.0/bin',
'Xinsir Union XL': 'xinsir/controlnet-union-sdxl-1.0',
+ 'Xinsir ProMax XL': 'brad-twinkl/controlnet-union-sdxl-1.0-promax',
'Xinsir OpenPose XL': 'xinsir/controlnet-openpose-sdxl-1.0',
'Xinsir Canny XL': 'xinsir/controlnet-canny-sdxl-1.0',
'Xinsir Depth XL': 'xinsir/controlnet-depth-sdxl-1.0',
'Xinsir Scribble XL': 'xinsir/controlnet-scribble-sdxl-1.0',
'Xinsir Anime Painter XL': 'xinsir/anime-painter',
+ 'Xinsir Tile XL': 'xinsir/controlnet-tile-sdxl-1.0',
'NoobAI Canny XL': 'Eugeoter/noob-sdxl-controlnet-canny',
'NoobAI Lineart Anime XL': 'Eugeoter/noob-sdxl-controlnet-lineart_anime',
'NoobAI Depth XL': 'Eugeoter/noob-sdxl-controlnet-depth',
'NoobAI Normal XL': 'Eugeoter/noob-sdxl-controlnet-normal',
'NoobAI SoftEdge XL': 'Eugeoter/noob-sdxl-controlnet-softedge_hed',
'NoobAI OpenPose XL': 'einar77/noob-openpose',
+ 'TTPlanet Tile Realistic XL': 'Yakonrus/SDXL_Controlnet_Tile_Realistic_v2',
# 'StabilityAI Canny R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-canny-rank128.safetensors',
# 'StabilityAI Depth R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-depth-rank128.safetensors',
# 'StabilityAI Recolor R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-recolor-rank128.safetensors',
@@ -75,6 +79,8 @@
"InstantX Union": 'InstantX/FLUX.1-dev-Controlnet-Union',
"InstantX Canny": 'InstantX/FLUX.1-dev-Controlnet-Canny',
"JasperAI Depth": 'jasperai/Flux.1-dev-Controlnet-Depth',
+ "BlackForrestLabs Canny LoRA": '/huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora/flux1-canny-dev-lora.safetensors',
+ "BlackForrestLabs Depth LoRA": '/huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora/flux1-depth-dev-lora.safetensors',
"JasperAI Surface Normals": 'jasperai/Flux.1-dev-Controlnet-Surface-Normals',
"JasperAI Upscaler": 'jasperai/Flux.1-dev-Controlnet-Upscaler',
"Shakker-Labs Union": 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro',
@@ -85,6 +91,9 @@
"XLabs-AI HED": 'XLabs-AI/flux-controlnet-hed-diffusers'
}
predefined_sd3 = {
+ "StabilityAI Canny": 'diffusers-internal-dev/sd35-controlnet-canny-8b',
+ "StabilityAI Depth": 'diffusers-internal-dev/sd35-controlnet-depth-8b',
+ "StabilityAI Blur": 'diffusers-internal-dev/sd35-controlnet-blur-8b',
"InstantX Canny": 'InstantX/SD3-Controlnet-Canny',
"InstantX Pose": 'InstantX/SD3-Controlnet-Pose',
"InstantX Depth": 'InstantX/SD3-Controlnet-Depth',
@@ -92,6 +101,14 @@
"Alimama Inpainting": 'alimama-creative/SD3-Controlnet-Inpainting',
"Alimama SoftEdge": 'alimama-creative/SD3-Controlnet-Softedge',
}
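+# weight variant requested at load time for specific models (passed as variant= to from_pretrained)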
+variants = {
+ 'NoobAI Canny XL': 'fp16',
+ 'NoobAI Lineart Anime XL': 'fp16',
+ 'NoobAI Depth XL': 'fp16',
+ 'NoobAI Normal XL': 'fp16',
+ 'NoobAI SoftEdge XL': 'fp16',
+ 'TTPlanet Tile Realistic XL': 'fp16',
+}
models = {}
all_models = {}
all_models.update(predefined_sd15)
@@ -159,26 +176,35 @@ def reset(self):
self.model = None
self.model_id = None
- def get_class(self):
- import modules.shared
- if modules.shared.sd_model_type == 'sd':
+ def get_class(self, model_id:str=''):
+ from modules import shared
+ if shared.sd_model_type == 'none':
+ _load = shared.sd_model # trigger a load
+ if shared.sd_model_type == 'sd':
from diffusers import ControlNetModel as cls # pylint: disable=reimported
config = 'lllyasviel/control_v11p_sd15_canny'
- elif modules.shared.sd_model_type == 'sdxl':
- from diffusers import ControlNetModel as cls # pylint: disable=reimported # sdxl shares same model class
- config = 'Eugeoter/noob-sdxl-controlnet-canny'
- elif modules.shared.sd_model_type == 'f1':
+ elif shared.sd_model_type == 'sdxl':
+ if 'union' in model_id.lower():
+ from diffusers import ControlNetUnionModel as cls
+ config = 'xinsir/controlnet-union-sdxl-1.0'
+ elif 'promax' in model_id.lower():
+ from diffusers import ControlNetUnionModel as cls
+ config = 'brad-twinkl/controlnet-union-sdxl-1.0-promax'
+ else:
+ from diffusers import ControlNetModel as cls # pylint: disable=reimported # sdxl shares same model class
+ config = 'Eugeoter/noob-sdxl-controlnet-canny'
+ elif shared.sd_model_type == 'f1':
from diffusers import FluxControlNetModel as cls
config = 'InstantX/FLUX.1-dev-Controlnet-Union'
- elif modules.shared.sd_model_type == 'sd3':
+ elif shared.sd_model_type == 'sd3':
from diffusers import SD3ControlNetModel as cls
config = 'InstantX/SD3-Controlnet-Canny'
else:
- log.error(f'Control {what}: type={modules.shared.sd_model_type} unsupported model')
+ log.error(f'Control {what}: type={shared.sd_model_type} unsupported model')
return None, None
return cls, config
- def load_safetensors(self, model_path):
+ def load_safetensors(self, model_id, model_path):
name = os.path.splitext(model_path)[0]
config_path = None
if not os.path.exists(model_path):
@@ -203,7 +229,7 @@ def load_safetensors(self, model_path):
config_path = f'{name}.json'
if config_path is not None:
self.load_config['original_config_file '] = config_path
- cls, config = self.get_class()
+ cls, config = self.get_class(model_id)
if cls is None:
log.error(f'Control {what} model load failed: unknown base model')
else:
@@ -225,23 +251,26 @@ def load(self, model_id: str = None, force: bool = True) -> str:
if model_path is None:
log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id')
return
+ if 'lora' in model_id.lower():
+ self.model = model_path
+ return
if model_id == self.model_id and not force:
log.debug(f'Control {what} model: id="{model_id}" path="{model_path}" already loaded')
return
log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"')
+ cls, _config = self.get_class(model_id)
if model_path.endswith('.safetensors'):
- self.load_safetensors(model_path)
+ self.load_safetensors(model_id, model_path)
else:
kwargs = {}
if '/bin' in model_path:
model_path = model_path.replace('/bin', '')
self.load_config['use_safetensors'] = False
- cls, _config = self.get_class()
if cls is None:
log.error(f'Control {what} model load failed: id="{model_id}" unknown base model')
return
- if 'Eugeoter' in model_path:
- kwargs['variant'] = 'fp16'
+ if variants.get(model_id, None) is not None:
+ kwargs['variant'] = variants[model_id]
self.model = cls.from_pretrained(model_path, **self.load_config, **kwargs)
if self.model is None:
return
@@ -268,7 +297,7 @@ def load(self, model_id: str = None, force: bool = True) -> str:
self.model.to(self.device)
t1 = time.time()
self.model_id = model_id
- log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}')
+ log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" cls={cls.__name__} time={t1-t0:.2f}')
return f'{what} loaded model: {model_id}'
except Exception as e:
log.error(f'Control {what} model load failed: id="{model_id}" error={e}')
@@ -281,16 +310,27 @@ def __init__(self,
controlnet: Union[ControlNetModel, list[ControlNetModel]],
pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline, FluxPipeline, StableDiffusion3Pipeline],
dtype = None,
+ p: StableDiffusionProcessingControl = None, # pylint: disable=unused-argument
):
t0 = time.time()
self.orig_pipeline = pipeline
self.pipeline = None
+
+ controlnets = controlnet if isinstance(controlnet, list) else [controlnet]
+ loras = [cn for cn in controlnets if isinstance(cn, str)]
+ controlnets = [cn for cn in controlnets if not isinstance(cn, str)]
+
if pipeline is None:
log.error('Control model pipeline: model not loaded')
return
- elif detect.is_sdxl(pipeline):
- from diffusers import StableDiffusionXLControlNetPipeline
- self.pipeline = StableDiffusionXLControlNetPipeline(
+ elif detect.is_sdxl(pipeline) and len(controlnets) > 0:
+ from diffusers import StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetUnionPipeline
+ if controlnet.__class__.__name__ == 'ControlNetUnionModel':
+ cls = StableDiffusionXLControlNetUnionPipeline
+ controlnets = controlnets[0] # using only first one
+ else:
+ cls = StableDiffusionXLControlNetPipeline
+ self.pipeline = cls(
vae=pipeline.vae,
text_encoder=pipeline.text_encoder,
text_encoder_2=pipeline.text_encoder_2,
@@ -299,9 +339,9 @@ def __init__(self,
unet=pipeline.unet,
scheduler=pipeline.scheduler,
feature_extractor=getattr(pipeline, 'feature_extractor', None),
- controlnet=controlnet, # can be a list
+ controlnet=controlnets, # can be a list
)
- elif detect.is_sd15(pipeline):
+ elif detect.is_sd15(pipeline) and len(controlnets) > 0:
from diffusers import StableDiffusionControlNetPipeline
self.pipeline = StableDiffusionControlNetPipeline(
vae=pipeline.vae,
@@ -312,10 +352,10 @@ def __init__(self,
feature_extractor=getattr(pipeline, 'feature_extractor', None),
requires_safety_checker=False,
safety_checker=None,
- controlnet=controlnet, # can be a list
+ controlnet=controlnets, # can be a list
)
sd_models.move_model(self.pipeline, pipeline.device)
- elif detect.is_f1(pipeline):
+ elif detect.is_f1(pipeline) and len(controlnets) > 0:
from diffusers import FluxControlNetPipeline
self.pipeline = FluxControlNetPipeline(
vae=pipeline.vae.to(devices.device),
@@ -325,9 +365,9 @@ def __init__(self,
tokenizer_2=pipeline.tokenizer_2,
transformer=pipeline.transformer,
scheduler=pipeline.scheduler,
- controlnet=controlnet, # can be a list
+ controlnet=controlnets, # can be a list
)
- elif detect.is_sd3(pipeline):
+ elif detect.is_sd3(pipeline) and len(controlnets) > 0:
from diffusers import StableDiffusion3ControlNetPipeline
self.pipeline = StableDiffusion3ControlNetPipeline(
vae=pipeline.vae,
@@ -339,8 +379,18 @@ def __init__(self,
tokenizer_3=pipeline.tokenizer_3,
transformer=pipeline.transformer,
scheduler=pipeline.scheduler,
- controlnet=controlnet, # can be a list
+ controlnet=controlnets, # can be a list
)
+ elif len(loras) > 0:
+ self.pipeline = pipeline
+ for lora in loras:
+ log.debug(f'Control {what} pipeline: lora="{lora}"')
+ lora = lora.replace('/huggingface.co/', '')
+ self.pipeline.load_lora_weights(lora)
+ """
+ if p is not None:
+ p.prompt += f''
+ """
else:
log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type')
return
@@ -350,6 +400,7 @@ def __init__(self,
return
if dtype is not None:
self.pipeline = self.pipeline.to(dtype)
+
if opts.diffusers_offload_mode == 'none':
sd_models.move_model(self.pipeline, devices.device)
from modules.sd_models import set_diffuser_offload
@@ -359,5 +410,6 @@ def __init__(self,
log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}')
def restore(self):
+ self.pipeline.unload_lora_weights()
self.pipeline = None
return self.orig_pipeline
diff --git a/modules/ctrlx/__init__.py b/modules/ctrlx/__init__.py
index 87ff52f6c..07d06aeac 100644
--- a/modules/ctrlx/__init__.py
+++ b/modules/ctrlx/__init__.py
@@ -136,7 +136,7 @@ def appearance_guidance_scale(self):
@torch.no_grad()
def __call__(
self,
- prompt: Union[str, List[str]] = None, # TODO: Support prompt_2 and negative_prompt_2
+ prompt: Union[str, List[str]] = None,
structure_prompt: Optional[Union[str, List[str]]] = None,
appearance_prompt: Optional[Union[str, List[str]]] = None,
structure_image: Optional[PipelineImageInput] = None,
@@ -180,7 +180,6 @@ def __call__(
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
- # TODO: Add function argument documentation
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
@@ -205,7 +204,7 @@ def __call__(
target_size = target_size or (height, width)
# 1. Check inputs. Raise error if not correct
- self.check_inputs( # TODO: Custom check_inputs for our method
+ self.check_inputs(
prompt,
None, # prompt_2
height,
@@ -425,7 +424,7 @@ def denoising_value_valid(dnv):
# 7.2 Optionally get guidance scale embedding
timestep_cond = None
- if self.unet.config.time_cond_proj_dim is not None: # TODO: Make guidance scale embedding work with batch_order
+ if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
timestep_cond = self.get_guidance_scale_embedding(
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
@@ -457,7 +456,6 @@ def denoising_value_valid(dnv):
register_attr(self, t=t.item(), do_control=True, batch_order=batch_order)
- # TODO: For now, assume we are doing classifier-free guidance, support no CF-guidance later
latent_model_input = self.scheduler.scale_model_input(latents, t)
structure_latent_model_input = self.scheduler.scale_model_input(structure_latents, t)
appearance_latent_model_input = self.scheduler.scale_model_input(appearance_latents, t)
@@ -563,7 +561,7 @@ def denoising_value_valid(dnv):
# Self-recurrence
for _ in range(self_recurrence_schedule[i]):
if hasattr(self.scheduler, "_step_index"): # For fancier schedulers
- self.scheduler._step_index -= 1 # TODO: Does this actually work?
+ self.scheduler._step_index -= 1
t_prev = 0 if i + 1 >= num_inference_steps else timesteps[i + 1]
latents = noise_t2t(self.scheduler, t_prev, t, latents)
diff --git a/modules/devices.py b/modules/devices.py
index 56ac50091..949fab4aa 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -186,53 +186,68 @@ def get_device_for(task): # pylint: disable=unused-argument
return get_optimal_device()
-def torch_gc(force=False, fast=False):
+def torch_gc(force:bool=False, fast:bool=False, reason:str=None):
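+    # Run Python GC and flush the CUDA cache when VRAM usage crosses the configured threshold,
+    # a new out-of-memory event is detected, or force is set; returns current (gpu, ram) usage.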
+ def get_stats():
+ mem_dict = memstats.memory_stats()
+ gpu_dict = mem_dict.get('gpu', {})
+ ram_dict = mem_dict.get('ram', {})
+ oom = gpu_dict.get('oom', 0)
+ ram = ram_dict.get('used', 0)
+ if backend == "directml":
+ gpu = torch.cuda.memory_allocated() / (1 << 30)
+ else:
+ gpu = gpu_dict.get('used', 0)
+ used_gpu = round(100 * gpu / gpu_dict.get('total', 1)) if gpu_dict.get('total', 1) > 1 else 0
+ used_ram = round(100 * ram / ram_dict.get('total', 1)) if ram_dict.get('total', 1) > 1 else 0
+ return gpu, used_gpu, ram, used_ram, oom
+
+ global previous_oom # pylint: disable=global-statement
import gc
from modules import timer, memstats
from modules.shared import cmd_opts
+
t0 = time.time()
- mem = memstats.memory_stats()
- gpu = mem.get('gpu', {})
- ram = mem.get('ram', {})
- oom = gpu.get('oom', 0)
- if backend == "directml":
- used_gpu = round(100 * torch.cuda.memory_allocated() / (1 << 30) / gpu.get('total', 1)) if gpu.get('total', 1) > 1 else 0
- else:
- used_gpu = round(100 * gpu.get('used', 0) / gpu.get('total', 1)) if gpu.get('total', 1) > 1 else 0
- used_ram = round(100 * ram.get('used', 0) / ram.get('total', 1)) if ram.get('total', 1) > 1 else 0
- global previous_oom # pylint: disable=global-statement
+ gpu, used_gpu, ram, _used_ram, oom = get_stats()
threshold = 0 if (cmd_opts.lowvram and not cmd_opts.use_zluda) else opts.torch_gc_threshold
collected = 0
- if force or threshold == 0 or used_gpu >= threshold or used_ram >= threshold:
+ if reason is None and force:
+        reason = 'force'
+ if threshold == 0 or used_gpu >= threshold:
force = True
+ if reason is None:
+ reason = 'threshold'
if oom > previous_oom:
previous_oom = oom
- log.warning(f'Torch GPU out-of-memory error: {mem}')
+ log.warning(f'Torch GPU out-of-memory error: {memstats.memory_stats()}')
force = True
+ if reason is None:
+ reason = 'oom'
if force:
# actual gc
collected = gc.collect() if not fast else 0 # python gc
if cuda_ok:
try:
with torch.cuda.device(get_cuda_device_string()):
+ torch.cuda.synchronize()
torch.cuda.empty_cache() # cuda gc
torch.cuda.ipc_collect()
except Exception:
pass
+ else:
+ return gpu, ram
t1 = time.time()
- if 'gc' not in timer.process.records:
- timer.process.records['gc'] = 0
- timer.process.records['gc'] += t1 - t0
- if not force or collected == 0:
- return
- mem = memstats.memory_stats()
- saved = round(gpu.get('used', 0) - mem.get('gpu', {}).get('used', 0), 2)
- before = { 'gpu': gpu.get('used', 0), 'ram': ram.get('used', 0) }
- after = { 'gpu': mem.get('gpu', {}).get('used', 0), 'ram': mem.get('ram', {}).get('used', 0), 'retries': mem.get('retries', 0), 'oom': mem.get('oom', 0) }
- utilization = { 'gpu': used_gpu, 'ram': used_ram, 'threshold': threshold }
- results = { 'collected': collected, 'saved': saved }
+ timer.process.add('gc', t1 - t0)
+ if fast:
+ return gpu, ram
+
+ new_gpu, new_used_gpu, new_ram, new_used_ram, oom = get_stats()
+ before = { 'gpu': gpu, 'ram': ram }
+ after = { 'gpu': new_gpu, 'ram': new_ram, 'oom': oom }
+ utilization = { 'gpu': new_used_gpu, 'ram': new_used_ram }
+ results = { 'gpu': round(gpu - new_gpu, 2), 'py': collected }
fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
- log.debug(f'GC: utilization={utilization} gc={results} before={before} after={after} device={torch.device(get_optimal_device_name())} fn={fn} time={round(t1 - t0, 2)}') # pylint: disable=protected-access
+ log.debug(f'GC: current={after} prev={before} load={utilization} gc={results} fn={fn} why={reason} time={t1-t0:.2f}')
+ return new_gpu, new_ram
def set_cuda_sync_mode(mode):
@@ -471,7 +486,7 @@ def set_cuda_params():
device_name = get_raw_openvino_device()
else:
device_name = torch.device(get_optimal_device_name())
- log.info(f'Torch parameters: backend={backend} device={device_name} config={opts.cuda_dtype} dtype={dtype} vae={dtype_vae} unet={dtype_unet} context={inference_context.__name__} nohalf={opts.no_half} nohalfvae={opts.no_half_vae} upscast={opts.upcast_sampling} deterministic={opts.cudnn_deterministic} test-fp16={fp16_ok} test-bf16={bf16_ok} optimization="{opts.cross_attention_optimization}"')
+ log.info(f'Torch parameters: backend={backend} device={device_name} config={opts.cuda_dtype} dtype={dtype} vae={dtype_vae} unet={dtype_unet} context={inference_context.__name__} nohalf={opts.no_half} nohalfvae={opts.no_half_vae} upcast={opts.upcast_sampling} deterministic={opts.cudnn_deterministic} test-fp16={fp16_ok} test-bf16={bf16_ok} optimization="{opts.cross_attention_optimization}"')
def cond_cast_unet(tensor):
diff --git a/modules/devices_mac.py b/modules/devices_mac.py
index fe7c80f31..2ddc1fc22 100644
--- a/modules/devices_mac.py
+++ b/modules/devices_mac.py
@@ -24,7 +24,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs): # pylint: disable=redefined
output_dtype = kwargs.get('dtype', input.dtype)
if output_dtype == torch.int64:
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
- elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
+ elif output_dtype == torch.bool or (cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16)):
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
return cumsum_func(input, *args, **kwargs)
@@ -42,7 +42,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs): # pylint: disable=redefined
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
- lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
+ lambda _, self, *args, **kwargs: self.device.type != 'mps' and ((args and isinstance(args[0], torch.device) and args[0].type == 'mps') or (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')))
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
diff --git a/modules/errors.py b/modules/errors.py
index 527884cf1..6302057d7 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -36,7 +36,7 @@ def print_error_explanation(message):
log.error(line)
-def display(e: Exception, task, suppress=[]):
+def display(e: Exception, task: str, suppress=[]):
log.error(f"{task or 'error'}: {type(e).__name__}")
console.print_exception(show_locals=False, max_frames=16, extra_lines=1, suppress=suppress, theme="ansi_dark", word_wrap=False, width=console.width)
@@ -48,7 +48,7 @@ def display_once(e: Exception, task):
already_displayed[task] = 1
-def run(code, task):
+def run(code, task: str):
try:
code()
except Exception as e:
@@ -59,14 +59,14 @@ def exception(suppress=[]):
console.print_exception(show_locals=False, max_frames=16, extra_lines=2, suppress=suppress, theme="ansi_dark", word_wrap=False, width=min([console.width, 200]))
-def profile(profiler, msg: str, n: int = 5):
+def profile(profiler, msg: str, n: int = 16):
profiler.disable()
import io
import pstats
stream = io.StringIO() # pylint: disable=abstract-class-instantiated
p = pstats.Stats(profiler, stream=stream)
p.sort_stats(pstats.SortKey.CUMULATIVE)
- p.print_stats(100)
+ p.print_stats(200)
# p.print_title()
# p.print_call_heading(10, 'time')
# p.print_callees(10)
@@ -81,6 +81,7 @@ def profile(profiler, msg: str, n: int = 5):
and '_lsprof' not in x
and '/profiler' not in x
and 'rich' not in x
+ and 'profile_torch' not in x
and x.strip() != ''
]
txt = '\n'.join(lines[:min(n, len(lines))])
diff --git a/modules/extensions.py b/modules/extensions.py
index 5a8a53d29..ccd92dbf0 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -154,4 +154,4 @@ def list_extensions():
for dirname, path, is_builtin in extension_paths:
extension = Extension(name=dirname, path=path, enabled=dirname not in disabled_extensions, is_builtin=is_builtin)
extensions.append(extension)
- shared.log.info(f'Disabled extensions: {[e.name for e in extensions if not e.enabled]}')
+ shared.log.debug(f'Disabled extensions: {[e.name for e in extensions if not e.enabled]}')
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index b464bd349..fe141cca1 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -1,6 +1,7 @@
import re
+import inspect
from collections import defaultdict
-from modules import errors, shared, devices
+from modules import errors, shared
extra_network_registry = {}
@@ -15,10 +16,14 @@ def register_extra_network(extra_network):
def register_default_extra_networks():
- from modules.ui_extra_networks_hypernet import ExtraNetworkHypernet
- register_extra_network(ExtraNetworkHypernet())
from modules.ui_extra_networks_styles import ExtraNetworkStyles
register_extra_network(ExtraNetworkStyles())
+ if shared.native:
+ from modules.lora.networks import extra_network_lora
+ register_extra_network(extra_network_lora)
+ if shared.opts.hypernetwork_enabled:
+ from modules.ui_extra_networks_hypernet import ExtraNetworkHypernet
+ register_extra_network(ExtraNetworkHypernet())
class ExtraNetworkParams:
@@ -70,9 +75,12 @@ def is_stepwise(en_obj):
return any([len(str(x).split("@")) > 1 for x in all_args]) # noqa C419 # pylint: disable=use-a-generator
-def activate(p, extra_network_data, step=0):
+def activate(p, extra_network_data=None, step=0, include=[], exclude=[]):
"""call activate for extra networks in extra_network_data in specified order, then call activate for all remaining registered networks with an empty argument list"""
- if extra_network_data is None:
+ if p.disable_extra_networks:
+ return
+ extra_network_data = extra_network_data or p.network_data
+ if extra_network_data is None or len(extra_network_data) == 0:
return
stepwise = False
for extra_network_args in extra_network_data.values():
@@ -82,35 +90,42 @@ def activate(p, extra_network_data, step=0):
shared.log.warning("Composable LoRA not compatible with 'lora_force_diffusers'")
stepwise = False
shared.opts.data['lora_functional'] = stepwise or functional
- with devices.autocast():
- for extra_network_name, extra_network_args in extra_network_data.items():
- extra_network = extra_network_registry.get(extra_network_name, None)
- if extra_network is None:
- errors.log.warning(f"Skipping unknown extra network: {extra_network_name}")
- continue
- try:
+
+ for extra_network_name, extra_network_args in extra_network_data.items():
+ extra_network = extra_network_registry.get(extra_network_name, None)
+ if extra_network is None:
+ errors.log.warning(f"Skipping unknown extra network: {extra_network_name}")
+ continue
+ try:
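+            # pass include/exclude filters only to network handlers whose activate() signature accepts them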
+ signature = list(inspect.signature(extra_network.activate).parameters)
+ if 'include' in signature and 'exclude' in signature:
+ extra_network.activate(p, extra_network_args, step=step, include=include, exclude=exclude)
+ else:
extra_network.activate(p, extra_network_args, step=step)
- except Exception as e:
- errors.display(e, f"Activating network: type={extra_network_name} args:{extra_network_args}")
-
- for extra_network_name, extra_network in extra_network_registry.items():
- args = extra_network_data.get(extra_network_name, None)
- if args is not None:
- continue
- try:
- extra_network.activate(p, [])
- except Exception as e:
- errors.display(e, f"Activating network: type={extra_network_name}")
-
- p.extra_network_data = extra_network_data
+ except Exception as e:
+ errors.display(e, f"Activating network: type={extra_network_name} args:{extra_network_args}")
+
+ for extra_network_name, extra_network in extra_network_registry.items():
+ args = extra_network_data.get(extra_network_name, None)
+ if args is not None:
+ continue
+ try:
+ extra_network.activate(p, [])
+ except Exception as e:
+ errors.display(e, f"Activating network: type={extra_network_name}")
+
+ p.network_data = extra_network_data
if stepwise:
p.stepwise_lora = True
shared.opts.data['lora_functional'] = functional
-def deactivate(p, extra_network_data):
+def deactivate(p, extra_network_data=None):
"""call deactivate for extra networks in extra_network_data in specified order, then call deactivate for all remaining registered networks"""
- if extra_network_data is None:
+ if p.disable_extra_networks:
+ return
+ extra_network_data = extra_network_data or p.network_data
+ if extra_network_data is None or len(extra_network_data) == 0:
return
for extra_network_name in extra_network_data:
extra_network = extra_network_registry.get(extra_network_name, None)
diff --git a/modules/face/faceid.py b/modules/face/faceid.py
index b74e15dc5..4a4f07531 100644
--- a/modules/face/faceid.py
+++ b/modules/face/faceid.py
@@ -204,7 +204,6 @@ def face_id(
ip_model_dict["face_image"] = face_images
ip_model_dict["faceid_embeds"] = face_embeds # overwrite placeholder
faceid_model.set_scale(scale)
- extra_network_data = None
if p.all_prompts is None or len(p.all_prompts) == 0:
processing.process_init(p)
@@ -215,11 +214,9 @@ def face_id(
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n+1) * p.batch_size]
p.seeds = p.all_seeds[n * p.batch_size:(n+1) * p.batch_size]
p.subseeds = p.all_subseeds[n * p.batch_size:(n+1) * p.batch_size]
- p.prompts, extra_network_data = extra_networks.parse_prompts(p.prompts)
+ p.prompts, p.network_data = extra_networks.parse_prompts(p.prompts)
- if not p.disable_extra_networks:
- with devices.autocast():
- extra_networks.activate(p, extra_network_data)
+ extra_networks.activate(p, p.network_data)
ip_model_dict.update({
"prompt": p.prompts[0],
"negative_prompt": p.negative_prompts[0],
@@ -239,8 +236,7 @@ def face_id(
devices.torch_gc()
ipadapter.unapply(p.sd_model)
- if not p.disable_extra_networks:
- extra_networks.deactivate(p, extra_network_data)
+ extra_networks.deactivate(p, p.network_data)
p.extra_generation_params["IP Adapter"] = f"{basename}:{scale}"
finally:
diff --git a/modules/face/instantid_model.py b/modules/face/instantid_model.py
index 51a4d7850..543b39ded 100644
--- a/modules/face/instantid_model.py
+++ b/modules/face/instantid_model.py
@@ -344,7 +344,6 @@ def __call__(
return hidden_states
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
- # TODO attention_mask
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
diff --git a/modules/face/photomaker_model.py b/modules/face/photomaker_model.py
index b62fe73b8..3595c6a36 100644
--- a/modules/face/photomaker_model.py
+++ b/modules/face/photomaker_model.py
@@ -244,7 +244,7 @@ def encode_prompt_with_trigger_word(
prompt_embeds_list = []
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
- input_ids = tokenizer.encode(prompt) # TODO: batch encode
+ input_ids = tokenizer.encode(prompt)
clean_index = 0
clean_input_ids = []
class_token_index = []
@@ -296,7 +296,7 @@ def encode_prompt_with_trigger_word(
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
- class_tokens_mask = class_tokens_mask.to(device=device) # TODO: ignoring two-prompt case
+ class_tokens_mask = class_tokens_mask.to(device=device)
return prompt_embeds, pooled_prompt_embeds, class_tokens_mask
@@ -332,7 +332,7 @@ def __call__(
callback_steps: int = 1,
# Added parameters (for PhotoMaker)
input_id_images: PipelineImageInput = None,
- start_merge_step: int = 0, # TODO: change to `style_strength_ratio` in the future
+ start_merge_step: int = 0,
class_tokens_mask: Optional[torch.LongTensor] = None,
prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
@@ -410,7 +410,7 @@ def __call__(
(
prompt_embeds_text_only,
negative_prompt_embeds,
- pooled_prompt_embeds_text_only, # TODO: replace the pooled_prompt_embeds with text only prompt
+ pooled_prompt_embeds_text_only,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt_text_only,
@@ -431,7 +431,7 @@ def __call__(
if not isinstance(input_id_images[0], torch.Tensor):
id_pixel_values = self.id_image_processor(input_id_images, return_tensors="pt").pixel_values
- id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype) # TODO: multiple prompts
+ id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype)
# 6. Get the update text embedding with the stacked ID embedding
prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask)
diff --git a/modules/files_cache.py b/modules/files_cache.py
index fa2241afc..d65e0f4f4 100644
--- a/modules/files_cache.py
+++ b/modules/files_cache.py
@@ -6,6 +6,7 @@
from installer import log
+do_cache_folders = os.environ.get('SD_NO_CACHE', None) is None
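+# setting the SD_NO_CACHE environment variable disables storing fetched directories in the cache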
class Directory: # forward declaration
...
@@ -87,8 +88,6 @@ def is_stale(self) -> bool:
return not self.is_directory or self.mtime != self.live_mtime
-
-
class DirectoryCache(UserDict, DirectoryCollection):
def __delattr__(self, directory_path: str) -> None:
directory: Directory = get_directory(directory_path, fetch=False)
@@ -126,7 +125,7 @@ def clean_directory(directory: Directory, /, recursive: RecursiveType=False) ->
return is_clean
-def get_directory(directory_or_path: str, /, fetch:bool=True) -> Union[Directory, None]:
+def get_directory(directory_or_path: str, /, fetch: bool=True) -> Union[Directory, None]:
if isinstance(directory_or_path, Directory):
if directory_or_path.is_directory:
return directory_or_path
@@ -136,8 +135,9 @@ def get_directory(directory_or_path: str, /, fetch:bool=True) -> Union[Directory
if not cache_folders.get(directory_or_path, None):
if fetch:
directory = fetch_directory(directory_path=directory_or_path)
- if directory:
+ if directory and do_cache_folders:
cache_folders[directory_or_path] = directory
+ return directory
else:
clean_directory(cache_folders[directory_or_path])
return cache_folders[directory_or_path] if directory_or_path in cache_folders else None
diff --git a/modules/freescale/__init__.py b/modules/freescale/__init__.py
new file mode 100644
index 000000000..7b9c17f5d
--- /dev/null
+++ b/modules/freescale/__init__.py
@@ -0,0 +1,4 @@
+# Credits: https://github.com/ali-vilab/FreeScale
+
+from .freescale_pipeline import StableDiffusionXLFreeScale
+from .freescale_pipeline_img2img import StableDiffusionXLFreeScaleImg2Img
diff --git a/modules/freescale/free_lunch_utils.py b/modules/freescale/free_lunch_utils.py
new file mode 100644
index 000000000..be26b732a
--- /dev/null
+++ b/modules/freescale/free_lunch_utils.py
@@ -0,0 +1,305 @@
+from typing import Any, Dict, Optional, Tuple
+import torch
+import torch.fft as fft
+from diffusers.utils import is_torch_version
+
+""" Borrowed from https://github.com/ChenyangSi/FreeU/blob/main/demo/free_lunch_utils.py
+"""
+
+def isinstance_str(x: object, cls_name: str):
+ """
+ Checks whether x has any class *named* cls_name in its ancestry.
+ Doesn't require access to the class's implementation.
+
+ Useful for patching!
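+
+    Example: isinstance_str(torch.nn.Linear(1, 1), "Module") -> True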
+ """
+
+ for _cls in x.__class__.__mro__:
+ if _cls.__name__ == cls_name:
+ return True
+
+ return False
+
+
+def Fourier_filter(x, threshold, scale):
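+    # FreeU helper: scales the low-frequency band of x (a square of half-width `threshold` around
+    # the centered FFT spectrum) by `scale`, leaving the remaining frequencies untouched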
+ dtype = x.dtype
+ x = x.type(torch.float32)
+ # FFT
+ x_freq = fft.fftn(x, dim=(-2, -1))
+ x_freq = fft.fftshift(x_freq, dim=(-2, -1))
+
+ B, C, H, W = x_freq.shape
+    mask = torch.ones((B, C, H, W), device=x.device) # build the mask on the input's device rather than hard-coding CUDA
+
+    crow, ccol = H // 2, W // 2
+ mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
+ x_freq = x_freq * mask
+
+ # IFFT
+ x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
+ x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
+
+ x_filtered = x_filtered.type(dtype)
+ return x_filtered
+
+
+def register_upblock2d(model):
+ def up_forward(self):
+ def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ #print(f"in upblock2d, hidden states shape: {hidden_states.shape}")
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ return forward
+
+ for i, upsample_block in enumerate(model.unet.up_blocks):
+ if isinstance_str(upsample_block, "UpBlock2D"):
+ upsample_block.forward = up_forward(upsample_block)
+
+
+def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
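+    # FreeU: b1/b2 amplify part of the backbone (hidden-state) channels of the 1280- and 640-channel
+    # up blocks, while s1/s2 attenuate the low frequencies of the matching skip connections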
+ def up_forward(self):
+ def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")
+
+ # --------------- FreeU code -----------------------
+ # Only operate on the first two stages
+ if hidden_states.shape[1] == 1280:
+ hidden_states[:,:640] = hidden_states[:,:640] * self.b1
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
+ if hidden_states.shape[1] == 640:
+ hidden_states[:,:320] = hidden_states[:,:320] * self.b2
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
+ # ---------------------------------------------------------
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ return forward
+
+ for i, upsample_block in enumerate(model.unet.up_blocks):
+ if isinstance_str(upsample_block, "UpBlock2D"):
+ upsample_block.forward = up_forward(upsample_block)
+ setattr(upsample_block, 'b1', b1)
+ setattr(upsample_block, 'b2', b2)
+ setattr(upsample_block, 's1', s1)
+ setattr(upsample_block, 's2', s2)
+
+
+def register_crossattn_upblock2d(model):
+ def up_forward(self):
+ def forward(
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ #print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}")
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ None, # timestep
+ None, # class_labels
+ cross_attention_kwargs,
+ attention_mask,
+ encoder_attention_mask,
+ **ckpt_kwargs,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ return forward
+
+ for i, upsample_block in enumerate(model.unet.up_blocks):
+ if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
+ upsample_block.forward = up_forward(upsample_block)
+
+
+def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
+ def up_forward(self):
+ def forward(
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ #print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}")
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ # --------------- FreeU code -----------------------
+ # Only operate on the first two stages
+ if hidden_states.shape[1] == 1280:
+ hidden_states[:,:640] = hidden_states[:,:640] * self.b1
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
+ if hidden_states.shape[1] == 640:
+ hidden_states[:,:320] = hidden_states[:,:320] * self.b2
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
+ # ---------------------------------------------------------
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ None, # timestep
+ None, # class_labels
+ cross_attention_kwargs,
+ attention_mask,
+ encoder_attention_mask,
+ **ckpt_kwargs,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ # hidden_states = attn(
+ # hidden_states,
+ # encoder_hidden_states=encoder_hidden_states,
+ # cross_attention_kwargs=cross_attention_kwargs,
+ # encoder_attention_mask=encoder_attention_mask,
+ # return_dict=False,
+ # )[0]
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )[0]
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ return forward
+
+ for i, upsample_block in enumerate(model.unet.up_blocks):
+ if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
+ upsample_block.forward = up_forward(upsample_block)
+ setattr(upsample_block, 'b1', b1)
+ setattr(upsample_block, 'b2', b2)
+ setattr(upsample_block, 's1', s1)
+ setattr(upsample_block, 's2', s2)
diff --git a/modules/freescale/freescale_pipeline.py b/modules/freescale/freescale_pipeline.py
new file mode 100644
index 000000000..9b7a68b68
--- /dev/null
+++ b/modules/freescale/freescale_pipeline.py
@@ -0,0 +1,1189 @@
+from inspect import isfunction
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import inspect
+import os
+import random
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.attention_processor import AttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.utils import is_accelerate_available, is_accelerate_version, logging, replace_example_docstring
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.models.attention import BasicTransformerBlock
+
+from .scale_attention import ori_forward, scale_forward
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+        >>> from modules.freescale import StableDiffusionXLFreeScale
+
+        >>> pipe = StableDiffusionXLFreeScale.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+def exists(val):
+ return val is not None
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ elif schedule == "sqrt":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
+
+to_torch = partial(torch.tensor, dtype=torch.float16)
+betas = make_beta_schedule("linear", 1000, linear_start=0.00085, linear_end=0.012)
+alphas = 1. - betas
+alphas_cumprod = np.cumprod(alphas, axis=0)
+sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod))
+sqrt_one_minus_alphas_cumprod = to_torch(np.sqrt(1. - alphas_cumprod))
+
+def q_sample(x_start, t, init_noise_sigma = 1.0, noise=None, device=None):
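+    # forward diffusion q(x_t | x_0): sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise * init_noise_sigma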
+ noise = default(noise, lambda: torch.randn_like(x_start)).to(device) * init_noise_sigma
+ return (extract_into_tensor(sqrt_alphas_cumprod.to(device), t, x_start.shape) * x_start +
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod.to(device), t, x_start.shape) * noise)
+
+def get_views(height, width, h_window_size=128, w_window_size=128, h_window_stride=64, w_window_stride=64, vae_scale_factor=8):
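+    # split the latent canvas into overlapping windows and return their (h_start, h_end, w_start, w_end)
+    # coordinates; random jitter shifts interior windows to avoid fixed tile seams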
+ height //= vae_scale_factor
+ width //= vae_scale_factor
+ num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1
+ num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1
+ total_num_blocks = int(num_blocks_height * num_blocks_width)
+ views = []
+ for i in range(total_num_blocks):
+ h_start = int((i // num_blocks_width) * h_window_stride)
+ h_end = h_start + h_window_size
+ w_start = int((i % num_blocks_width) * w_window_stride)
+ w_end = w_start + w_window_size
+
+ if h_end > height:
+ h_start = int(h_start + height - h_end)
+ h_end = int(height)
+ if w_end > width:
+ w_start = int(w_start + width - w_end)
+ w_end = int(width)
+ if h_start < 0:
+ h_end = int(h_end - h_start)
+ h_start = 0
+ if w_start < 0:
+ w_end = int(w_end - w_start)
+ w_start = 0
+
+ random_jitter = True
+ if random_jitter:
+ h_jitter_range = (h_window_size - h_window_stride) // 4
+ w_jitter_range = (w_window_size - w_window_stride) // 4
+ h_jitter = 0
+ w_jitter = 0
+
+ if (w_start != 0) and (w_end != width):
+ w_jitter = random.randint(-w_jitter_range, w_jitter_range)
+ elif (w_start == 0) and (w_end != width):
+ w_jitter = random.randint(-w_jitter_range, 0)
+ elif (w_start != 0) and (w_end == width):
+ w_jitter = random.randint(0, w_jitter_range)
+ if (h_start != 0) and (h_end != height):
+ h_jitter = random.randint(-h_jitter_range, h_jitter_range)
+ elif (h_start == 0) and (h_end != height):
+ h_jitter = random.randint(-h_jitter_range, 0)
+ elif (h_start != 0) and (h_end == height):
+ h_jitter = random.randint(0, h_jitter_range)
+ h_start += (h_jitter + h_jitter_range)
+ h_end += (h_jitter + h_jitter_range)
+ w_start += (w_jitter + w_jitter_range)
+ w_end += (w_jitter + w_jitter_range)
+
+ views.append((h_start, h_end, w_start, w_end))
+ return views
+
+def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
+ x_coord = torch.arange(kernel_size)
+ gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
+ gaussian_1d = gaussian_1d / gaussian_1d.sum()
+ gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
+ kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)
+
+ return kernel
+
+def gaussian_filter(latents, kernel_size=3, sigma=1.0):
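+    # depthwise Gaussian blur over the spatial dims; 5D (video-style) latents are flattened to 4D for conv2d and restored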
+ channels = latents.shape[1]
+ kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
+ if len(latents.shape) == 5:
+ b = latents.shape[0]
+ latents = rearrange(latents, 'b c t i j -> (b t) c i j')
+ blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
+ blurred_latents = rearrange(blurred_latents, '(b t) c i j -> b c t i j', b=b)
+ else:
+ blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
+
+ return blurred_latents
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class StableDiffusionXLFreeScale(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ force_zeros_for_empty_prompt: bool = True,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.default_sample_size = self.unet.config.sample_size
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for
+        processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ model_sequence = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+ model_sequence.extend([self.unet, self.vae])
+
+ hook = None
+ for cpu_offloaded_model in model_sequence:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+            # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ resolutions_list: Optional[Union[int, List[int]]] = None,
+ restart_steps: Optional[Union[int, List[int]]] = None,
+ cosine_scale: float = 2.0,
+ dilate_tau: int = 35,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ # 0. Default height and width to unet
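+ # FreeScale-specific: `resolutions_list` is the progressive resolution schedule; the first entry is
+ # generated from scratch and every later entry is reached through the restart/upscale loop further below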
+ if resolutions_list:
+ height, width = resolutions_list[0]
+ target_sizes = resolutions_list[1:]
+ if not restart_steps:
+ restart_steps = [15] * len(target_sizes)
+ else:
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids = self._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 9.1 Apply denoising_end
+ if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ results_list = []
+
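+ # base-resolution pass: rebind the original BasicTransformerBlock forward (no FreeScale scaling)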
+ for block in self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks:
+ for module in block.modules():
+ if isinstance(module, BasicTransformerBlock):
+ module.forward = ori_forward.__get__(module, BasicTransformerBlock)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+ results_list.append(latents)
+
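+ # FreeScale restart loop: for each additional target resolution, switch the transformer blocks to the
+ # scaled-attention forward, upscale the current result, re-noise it to `restart_step` and denoise again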
+ for restart_index, target_size in enumerate(target_sizes):
+ restart_step = restart_steps[restart_index]
+ target_size_ = [target_size[0]//8, target_size[1]//8]
+
+ for block in self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks:
+ for module in block.modules():
+ if isinstance(module, BasicTransformerBlock):
+ module.forward = scale_forward.__get__(module, BasicTransformerBlock)
+ module.current_hw = target_size
+
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
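+ # decode to pixel space, bicubic-upsample to the next target size, then re-encode into latents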
+ latents = latents / self.vae.config.scaling_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = torch.nn.functional.interpolate(
+ image,
+ size=target_size,
+ mode='bicubic',
+ )
+ latents = self.vae.encode(image).latent_dist.sample().to(self.vae.dtype)
+ latents = latents * self.vae.config.scaling_factor
+
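+ # pre-compute re-noised copies of the upscaled latents for every timestep; sampling restarts at
+ # `restart_step` and the per-step copies are blended back in via the cosine schedule below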
+ noise_latents = []
+ noise = torch.randn_like(latents)
+ for timestep in self.scheduler.timesteps:
+ noise_latent = self.scheduler.add_noise(latents, noise, timestep.unsqueeze(0))
+ noise_latents.append(noise_latent)
+ latents = noise_latents[restart_step]
+
+ self.scheduler._step_index = 0
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+
+ if i < restart_step:
+ self.scheduler._step_index += 1
+ progress_bar.update()
+ continue
+
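+ # cosine-scheduled noise injection: c1 is ~1 at early (high-noise) steps and decays toward 0,
+ # so the re-noised upscaled latents dominate early while later steps keep the freshly denoised latents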
+ cosine_factor = 0.5 * (1 + torch.cos(torch.pi * (self.scheduler.config.num_train_timesteps - t) / self.scheduler.config.num_train_timesteps)).cpu()
+ c1 = cosine_factor ** cosine_scale
+ latents = latents * (1 - c1) + noise_latents[i] * c1
+
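+ # dilate selected down-/mid-block convolutions by the upscale factor (target width / 1024) for the
+ # first `dilate_tau` steps to enlarge their receptive field at the higher resolution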
+ dilate_coef = target_size[1] // 1024
+
+ dilate_layers = [
+ # "down_blocks.1.resnets.0.conv1",
+ # "down_blocks.1.resnets.0.conv2",
+ # "down_blocks.1.resnets.1.conv1",
+ # "down_blocks.1.resnets.1.conv2",
+ "down_blocks.1.downsamplers.0.conv",
+ "down_blocks.2.resnets.0.conv1",
+ "down_blocks.2.resnets.0.conv2",
+ "down_blocks.2.resnets.1.conv1",
+ "down_blocks.2.resnets.1.conv2",
+ # "up_blocks.0.resnets.0.conv1",
+ # "up_blocks.0.resnets.0.conv2",
+ # "up_blocks.0.resnets.1.conv1",
+ # "up_blocks.0.resnets.1.conv2",
+ # "up_blocks.0.resnets.2.conv1",
+ # "up_blocks.0.resnets.2.conv2",
+ # "up_blocks.0.upsamplers.0.conv",
+ # "up_blocks.1.resnets.0.conv1",
+ # "up_blocks.1.resnets.0.conv2",
+ # "up_blocks.1.resnets.1.conv1",
+ # "up_blocks.1.resnets.1.conv2",
+ # "up_blocks.1.resnets.2.conv1",
+ # "up_blocks.1.resnets.2.conv2",
+ # "up_blocks.1.upsamplers.0.conv",
+ # "up_blocks.2.resnets.0.conv1",
+ # "up_blocks.2.resnets.0.conv2",
+ # "up_blocks.2.resnets.1.conv1",
+ # "up_blocks.2.resnets.1.conv2",
+ # "up_blocks.2.resnets.2.conv1",
+ # "up_blocks.2.resnets.2.conv2",
+ "mid_block.resnets.0.conv1",
+ "mid_block.resnets.0.conv2",
+ "mid_block.resnets.1.conv1",
+ "mid_block.resnets.1.conv2"
+ ]
+
+ for name, module in self.unet.named_modules():
+ if name in dilate_layers:
+ if i < dilate_tau:
+ module.dilation = (dilate_coef, dilate_coef)
+ module.padding = (dilate_coef, dilate_coef)
+ else:
+ module.dilation = (1, 1)
+ module.padding = (1, 1)
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ for name, module in self.unet.named_modules():
+ # if ('.conv' in name) and ('.conv_' not in name):
+ if name in dilate_layers:
+ module.dilation = (1, 1)
+ module.padding = (1, 1)
+
+ results_list.append(latents)
+
+ """
+ final_results = []
+ for latents in results_list:
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if not return_dict:
+ final_results += [(image,)]
+ else:
+ final_results += [StableDiffusionXLPipelineOutput(images=image)]
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ return final_results
+ """
+ return StableDiffusionXLPipelineOutput(images=results_list)
+
+ # Override to properly handle the loading and unloading of the additional text encoder.
+ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+ # We could have accessed the unet config from `lora_state_dict()` too. We pass
+ # it here explicitly to be able to tell that it's coming from an SDXL
+ # pipeline.
+ state_dict, network_alphas = self.lora_state_dict(
+ pretrained_model_name_or_path_or_dict,
+ unet_config=self.unet.config,
+ **kwargs,
+ )
+ self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+ if len(text_encoder_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder,
+ prefix="text_encoder",
+ lora_scale=self.lora_scale,
+ )
+
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+ if len(text_encoder_2_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_2_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder_2,
+ prefix="text_encoder_2",
+ lora_scale=self.lora_scale,
+ )
+
+ @classmethod
+ def save_lora_weights(
+ self,
+ save_directory: Union[str, os.PathLike],
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ state_dict = {}
+
+ def pack_weights(layers, prefix):
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+ return layers_state_dict
+
+ state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
+ if text_encoder_lora_layers and text_encoder_2_lora_layers:
+ state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+ state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+ self.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def _remove_text_encoder_monkey_patch(self):
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
diff --git a/modules/freescale/freescale_pipeline_img2img.py b/modules/freescale/freescale_pipeline_img2img.py
new file mode 100644
index 000000000..df4c3f0c1
--- /dev/null
+++ b/modules/freescale/freescale_pipeline_img2img.py
@@ -0,0 +1,1245 @@
+from inspect import isfunction
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import inspect
+import os
+import random
+
+from PIL import Image
+import numpy as np
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+import torchvision.transforms as transforms
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.attention_processor import AttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.utils import is_accelerate_available, is_accelerate_version, logging, replace_example_docstring
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.models.attention import BasicTransformerBlock
+
+from .scale_attention import ori_forward, scale_forward
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionXLPipeline
+
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
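+# img2img helpers: convert the input image to a [-1, 1] float tensor and an optional mask to a binary {0, 1} tensor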
+def process_image_to_tensor(image):
+ image = image.convert("RGB")
+ # image = Image.open(image_path).convert("RGB")
+ transform = transforms.Compose(
+ [
+ # transforms.Resize((1024, 1024)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ]
+ )
+ image_tensor = transform(image)
+ return image_tensor
+
+def process_image_to_bitensor(image):
+ # image = Image.open(image_path).convert("L")
+ image = image.convert("L")
+ transform = transforms.ToTensor()
+ image_tensor = transform(image)
+ binary_tensor = torch.where(image_tensor != 0, torch.tensor(1.0), torch.tensor(0.0))
+ return binary_tensor
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+def exists(val):
+ return val is not None
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ elif schedule == "sqrt":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
+
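+# module-level DDPM linear beta schedule (1000 steps) used by q_sample to noise encoded image latents to an
+# arbitrary timestep; the cumulative-product terms are stored as float16 tensors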
+to_torch = partial(torch.tensor, dtype=torch.float16)
+betas = make_beta_schedule("linear", 1000, linear_start=0.00085, linear_end=0.012)
+alphas = 1. - betas
+alphas_cumprod = np.cumprod(alphas, axis=0)
+sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod))
+sqrt_one_minus_alphas_cumprod = to_torch(np.sqrt(1. - alphas_cumprod))
+
+def q_sample(x_start, t, init_noise_sigma = 1.0, noise=None, device=None):
+ noise = default(noise, lambda: torch.randn_like(x_start)).to(device) * init_noise_sigma
+ return (extract_into_tensor(sqrt_alphas_cumprod.to(device), t, x_start.shape) * x_start +
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod.to(device), t, x_start.shape) * noise)
+
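+# split the latent canvas into overlapping windows (with optional random jitter) for tiled processing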
+def get_views(height, width, h_window_size=128, w_window_size=128, h_window_stride=64, w_window_stride=64, vae_scale_factor=8):
+ height //= vae_scale_factor
+ width //= vae_scale_factor
+ num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1
+ num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1
+ total_num_blocks = int(num_blocks_height * num_blocks_width)
+ views = []
+ for i in range(total_num_blocks):
+ h_start = int((i // num_blocks_width) * h_window_stride)
+ h_end = h_start + h_window_size
+ w_start = int((i % num_blocks_width) * w_window_stride)
+ w_end = w_start + w_window_size
+
+ if h_end > height:
+ h_start = int(h_start + height - h_end)
+ h_end = int(height)
+ if w_end > width:
+ w_start = int(w_start + width - w_end)
+ w_end = int(width)
+ if h_start < 0:
+ h_end = int(h_end - h_start)
+ h_start = 0
+ if w_start < 0:
+ w_end = int(w_end - w_start)
+ w_start = 0
+
+ random_jitter = True
+ if random_jitter:
+ h_jitter_range = (h_window_size - h_window_stride) // 4
+ w_jitter_range = (w_window_size - w_window_stride) // 4
+ h_jitter = 0
+ w_jitter = 0
+
+ if (w_start != 0) and (w_end != width):
+ w_jitter = random.randint(-w_jitter_range, w_jitter_range)
+ elif (w_start == 0) and (w_end != width):
+ w_jitter = random.randint(-w_jitter_range, 0)
+ elif (w_start != 0) and (w_end == width):
+ w_jitter = random.randint(0, w_jitter_range)
+ if (h_start != 0) and (h_end != height):
+ h_jitter = random.randint(-h_jitter_range, h_jitter_range)
+ elif (h_start == 0) and (h_end != height):
+ h_jitter = random.randint(-h_jitter_range, 0)
+ elif (h_start != 0) and (h_end == height):
+ h_jitter = random.randint(0, h_jitter_range)
+ h_start += (h_jitter + h_jitter_range)
+ h_end += (h_jitter + h_jitter_range)
+ w_start += (w_jitter + w_jitter_range)
+ w_end += (w_jitter + w_jitter_range)
+
+ views.append((h_start, h_end, w_start, w_end))
+ return views
+
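+# depthwise Gaussian-blur helpers: build a per-channel Gaussian kernel and blur latents with a grouped conv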
+def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
+ x_coord = torch.arange(kernel_size)
+ gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
+ gaussian_1d = gaussian_1d / gaussian_1d.sum()
+ gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
+ kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)
+
+ return kernel
+
+def gaussian_filter(latents, kernel_size=3, sigma=1.0):
+ channels = latents.shape[1]
+ kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
+ if len(latents.shape) == 5:
+ b = latents.shape[0]
+ latents = rearrange(latents, 'b c t i j -> (b t) c i j')
+ blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
+ blurred_latents = rearrange(blurred_latents, '(b t) c i j -> b c t i j', b=b)
+ else:
+ blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
+
+ return blurred_latents
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class StableDiffusionXLFreeScaleImg2Img(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for image-to-image generation using Stable Diffusion XL with FreeScale.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.default_sample_size = self.unet.config.sample_size
+
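+ # tiled VAE decoding is enabled up front since this pipeline typically decodes latents far larger than the SDXL default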
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ model_sequence = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+ model_sequence.extend([self.unet, self.vae])
+
+ hook = None
+ for cpu_offloaded_model in model_sequence:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ resolutions_list: Optional[Union[int, List[int]]] = None,
+ restart_steps: Optional[Union[int, List[int]]] = None,
+ cosine_scale: float = 2.0,
+ cosine_scale_bg: float = 1.0,
+ dilate_tau: int = 35,
+ img_path: Optional[str] = "",
+ mask_path: Optional[str] = "",
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
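+ resolutions_list (`List`, *optional*):
+ FreeScale resolution schedule. The first entry is the base (height, width) and the remaining entries
+ are the progressively larger target sizes used for the upscaling passes.
+ restart_steps (`List[int]`, *optional*):
+ Denoising step index at which each upscaling pass restarts. Defaults to 15 for every target size.
+ cosine_scale (`float`, *optional*, defaults to 2.0):
+ Exponent of the cosine blending factor applied to foreground (masked) regions when mixing re-noised
+ and denoised latents.
+ cosine_scale_bg (`float`, *optional*, defaults to 1.0):
+ Exponent of the cosine blending factor applied to background regions when `mask_path` is provided.
+ dilate_tau (`int`, *optional*, defaults to 35):
+ Number of denoising steps during which selected UNet convolutions use dilation matched to the upscale factor.
+ img_path (`str`, *optional*):
+ Path to an input image. When set, the image is encoded to latents and the base text-to-image pass is skipped.
+ mask_path (`str`, *optional*):
+ Path to a mask image separating foreground (`cosine_scale`) from background (`cosine_scale_bg`) regions
+ during latent blending.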
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ # 0. Default height and width to unet
+ if resolutions_list:
+ height, width = resolutions_list[0]
+ target_sizes = resolutions_list[1:]
+ if not restart_steps:
+ restart_steps = [15] * len(target_sizes)
+ else:
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids = self._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 9.1 Apply denoising_end
+ if denoising_end is not None and isinstance(denoising_end, float) and 0 < denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ results_list = []
+
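+ # use the original transformer-block forward for the base-resolution pass; scale_forward is swapped in for the later upscaling passes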
+ for block in self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks:
+ for module in block.modules():
+ if isinstance(module, BasicTransformerBlock):
+ module.forward = ori_forward.__get__(module, BasicTransformerBlock)
+
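+ # img2img path: encode the provided image to latents and skip the base text-to-image denoising loop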
+ if img_path != '':
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ input_image = process_image_to_tensor(img_path).unsqueeze(0).to(dtype=self.vae.dtype, device=device)
+ latents = self.vae.encode(input_image).latent_dist.sample().to(self.vae.dtype)
+ latents = latents * self.vae.config.scaling_factor
+ else:
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ results_list.append(latents)
+
+ if mask_path != '':
+ mask = process_image_to_bitensor(mask_path).unsqueeze(0)
+
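+ # progressive upscaling: for each target size, decode to image space, upsample, re-encode, re-noise to the restart step and denoise with scaled attention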
+ for restart_index, target_size in enumerate(target_sizes):
+ restart_step = restart_steps[restart_index]
+ target_size_ = [target_size[0]//8, target_size[1]//8]
+
+ for block in self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks:
+ for module in block.modules():
+ if isinstance(module, BasicTransformerBlock):
+ module.forward = scale_forward.__get__(module, BasicTransformerBlock)
+ module.current_hw = target_size
+
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ latents = latents / self.vae.config.scaling_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = torch.nn.functional.interpolate(
+ image,
+ size=target_size,
+ mode='bicubic',
+ )
+ latents = self.vae.encode(image).latent_dist.sample().to(self.vae.dtype)
+ latents = latents * self.vae.config.scaling_factor
+
+ if mask_path != '':
+ mask_ = torch.nn.functional.interpolate(
+ mask,
+ size=target_size_,
+ mode="nearest",
+ ).to(device)
+
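+ # pre-compute re-noised latents for every scheduler timestep so denoising can restart mid-schedule and blend against them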
+ noise_latents = []
+ noise = torch.randn_like(latents)
+ for timestep in self.scheduler.timesteps:
+ noise_latent = self.scheduler.add_noise(latents, noise, timestep.unsqueeze(0))
+ noise_latents.append(noise_latent)
+ latents = noise_latents[restart_step]
+
+ self.scheduler._step_index = 0
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+
+ if i < restart_step:
+ self.scheduler._step_index += 1
+ progress_bar.update()
+ continue
+
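+ # cosine schedule: early steps keep more of the re-noised latents (c1 close to 1), later steps favor the freshly denoised latents; the mask applies different strengths to foreground and background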
+ cosine_factor = 0.5 * (1 + torch.cos(torch.pi * (self.scheduler.config.num_train_timesteps - t) / self.scheduler.config.num_train_timesteps)).cpu()
+ if mask_path != '':
+ c1 = (cosine_factor ** (mask_ * cosine_scale + (1-mask_) * cosine_scale_bg)).to(dtype=torch.float16)
+ else:
+ c1 = cosine_factor ** cosine_scale
+ latents = latents * (1 - c1) + noise_latents[i] * c1
+
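+ # dilate selected UNet convolutions in proportion to the upscale factor for the first dilate_tau steps to enlarge their receptive field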
+ dilate_coef = target_size[1] // 1024
+
+ dilate_layers = [
+ # "down_blocks.1.resnets.0.conv1",
+ # "down_blocks.1.resnets.0.conv2",
+ # "down_blocks.1.resnets.1.conv1",
+ # "down_blocks.1.resnets.1.conv2",
+ "down_blocks.1.downsamplers.0.conv",
+ "down_blocks.2.resnets.0.conv1",
+ "down_blocks.2.resnets.0.conv2",
+ "down_blocks.2.resnets.1.conv1",
+ "down_blocks.2.resnets.1.conv2",
+ # "up_blocks.0.resnets.0.conv1",
+ # "up_blocks.0.resnets.0.conv2",
+ # "up_blocks.0.resnets.1.conv1",
+ # "up_blocks.0.resnets.1.conv2",
+ # "up_blocks.0.resnets.2.conv1",
+ # "up_blocks.0.resnets.2.conv2",
+ # "up_blocks.0.upsamplers.0.conv",
+ # "up_blocks.1.resnets.0.conv1",
+ # "up_blocks.1.resnets.0.conv2",
+ # "up_blocks.1.resnets.1.conv1",
+ # "up_blocks.1.resnets.1.conv2",
+ # "up_blocks.1.resnets.2.conv1",
+ # "up_blocks.1.resnets.2.conv2",
+ # "up_blocks.1.upsamplers.0.conv",
+ # "up_blocks.2.resnets.0.conv1",
+ # "up_blocks.2.resnets.0.conv2",
+ # "up_blocks.2.resnets.1.conv1",
+ # "up_blocks.2.resnets.1.conv2",
+ # "up_blocks.2.resnets.2.conv1",
+ # "up_blocks.2.resnets.2.conv2",
+ "mid_block.resnets.0.conv1",
+ "mid_block.resnets.0.conv2",
+ "mid_block.resnets.1.conv1",
+ "mid_block.resnets.1.conv2"
+ ]
+
+ for name, module in self.unet.named_modules():
+ if name in dilate_layers:
+ if i < dilate_tau:
+ module.dilation = (dilate_coef, dilate_coef)
+ module.padding = (dilate_coef, dilate_coef)
+ else:
+ module.dilation = (1, 1)
+ module.padding = (1, 1)
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ for name, module in self.unet.named_modules():
+ # if ('.conv' in name) and ('.conv_' not in name):
+ if name in dilate_layers:
+ module.dilation = (1, 1)
+ module.padding = (1, 1)
+
+ results_list.append(latents)
+
+ """
+ final_results = []
+ for latents in results_list:
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if not return_dict:
+ final_results += [(image,)]
+ else:
+ final_results += [StableDiffusionXLPipelineOutput(images=image)]
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ return final_results
+ """
+ return StableDiffusionXLPipelineOutput(images=results_list)
+
+ # Override to properly handle the loading and unloading of the additional text encoder.
+ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+ # We could have accessed the unet config from `lora_state_dict()` too. We pass
+ # it here explicitly to be able to tell that it's coming from an SDXL
+ # pipeline.
+ state_dict, network_alphas = self.lora_state_dict(
+ pretrained_model_name_or_path_or_dict,
+ unet_config=self.unet.config,
+ **kwargs,
+ )
+ self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+ if len(text_encoder_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder,
+ prefix="text_encoder",
+ lora_scale=self.lora_scale,
+ )
+
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+ if len(text_encoder_2_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_2_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder_2,
+ prefix="text_encoder_2",
+ lora_scale=self.lora_scale,
+ )
+
+ @classmethod
+ def save_lora_weights(
+ self,
+ save_directory: Union[str, os.PathLike],
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ state_dict = {}
+
+ def pack_weights(layers, prefix):
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+ return layers_state_dict
+
+ state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
+ if text_encoder_lora_layers and text_encoder_2_lora_layers:
+ state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+ state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+ self.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def _remove_text_encoder_monkey_patch(self):
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
diff --git a/modules/freescale/scale_attention.py b/modules/freescale/scale_attention.py
new file mode 100644
index 000000000..9e83d5067
--- /dev/null
+++ b/modules/freescale/scale_attention.py
@@ -0,0 +1,367 @@
+from typing import Any, Dict, Optional
+import random
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+
+
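+# separable Gaussian blur used to extract the low-frequency component of attention outputs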
+def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
+ x_coord = torch.arange(kernel_size)
+ gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
+ gaussian_1d = gaussian_1d / gaussian_1d.sum()
+ gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
+ kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)
+
+ return kernel
+
+def gaussian_filter(latents, kernel_size=3, sigma=1.0):
+ channels = latents.shape[1]
+ kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
+ blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
+
+ return blurred_latents
+
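+# split the latent grid into overlapping windows (with optional random jitter) for local windowed self-attention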
+def get_views(height, width, h_window_size=128, w_window_size=128, scale_factor=8):
+ height = int(height)
+ width = int(width)
+ h_window_stride = h_window_size // 2
+ w_window_stride = w_window_size // 2
+ h_window_size = int(h_window_size / scale_factor)
+ w_window_size = int(w_window_size / scale_factor)
+ h_window_stride = int(h_window_stride / scale_factor)
+ w_window_stride = int(w_window_stride / scale_factor)
+ num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1
+ num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1
+ total_num_blocks = int(num_blocks_height * num_blocks_width)
+ views = []
+ for i in range(total_num_blocks):
+ h_start = int((i // num_blocks_width) * h_window_stride)
+ h_end = h_start + h_window_size
+ w_start = int((i % num_blocks_width) * w_window_stride)
+ w_end = w_start + w_window_size
+
+ if h_end > height:
+ h_start = int(h_start + height - h_end)
+ h_end = int(height)
+ if w_end > width:
+ w_start = int(w_start + width - w_end)
+ w_end = int(width)
+ if h_start < 0:
+ h_end = int(h_end - h_start)
+ h_start = 0
+ if w_start < 0:
+ w_end = int(w_end - w_start)
+ w_start = 0
+
+ random_jitter = True
+ if random_jitter:
+ h_jitter_range = h_window_size // 8
+ w_jitter_range = w_window_size // 8
+ h_jitter = 0
+ w_jitter = 0
+
+ if (w_start != 0) and (w_end != width):
+ w_jitter = random.randint(-w_jitter_range, w_jitter_range)
+ elif (w_start == 0) and (w_end != width):
+ w_jitter = random.randint(-w_jitter_range, 0)
+ elif (w_start != 0) and (w_end == width):
+ w_jitter = random.randint(0, w_jitter_range)
+ if (h_start != 0) and (h_end != height):
+ h_jitter = random.randint(-h_jitter_range, h_jitter_range)
+ elif (h_start == 0) and (h_end != height):
+ h_jitter = random.randint(-h_jitter_range, 0)
+ elif (h_start != 0) and (h_end == height):
+ h_jitter = random.randint(0, h_jitter_range)
+ h_start += (h_jitter + h_jitter_range)
+ h_end += (h_jitter + h_jitter_range)
+ w_start += (w_jitter + w_jitter_range)
+ w_end += (w_jitter + w_jitter_range)
+
+ views.append((h_start, h_end, w_start, w_end))
+ return views
+
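+# FreeScale self-attention: computes attention over overlapping local windows and fuses it with global attention in frequency space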
+def scale_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+):
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ if self.current_hw:
+ current_scale_num_h, current_scale_num_w = max(self.current_hw[0] // 1024, 1), max(self.current_hw[1] // 1024, 1)
+ else:
+ current_scale_num_h, current_scale_num_w = 1, 1
+
+ # 0. Self-Attention
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ # 2. Prepare GLIGEN inputs
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ ratio_hw = current_scale_num_h / current_scale_num_w
+ latent_h = int((norm_hidden_states.shape[1] * ratio_hw) ** 0.5)
+ latent_w = int(latent_h / ratio_hw)
+ scale_factor = 128 * current_scale_num_h / latent_h
+ if ratio_hw > 1:
+ sub_h = 128
+ sub_w = int(128 / ratio_hw)
+ else:
+ sub_h = int(128 * ratio_hw)
+ sub_w = 128
+
+ h_jitter_range = int(sub_h / scale_factor // 8)
+ w_jitter_range = int(sub_w / scale_factor // 8)
+ views = get_views(latent_h, latent_w, sub_h, sub_w, scale_factor = scale_factor)
+
+ current_scale_num = max(current_scale_num_h, current_scale_num_w)
+ global_views = [[h, w] for h in range(current_scale_num_h) for w in range(current_scale_num_w)]
+
+ four_window = True
+ fourg_window = False
+
+ if four_window:
+ norm_hidden_states_ = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
+ norm_hidden_states_ = F.pad(norm_hidden_states_, (0, 0, w_jitter_range, w_jitter_range, h_jitter_range, h_jitter_range), 'constant', 0)
+ value = torch.zeros_like(norm_hidden_states_)
+ count = torch.zeros_like(norm_hidden_states_)
+ for index, view in enumerate(views):
+ h_start, h_end, w_start, w_end = view
+ local_states = norm_hidden_states_[:, h_start:h_end, w_start:w_end, :]
+ local_states = rearrange(local_states, 'bh h w d -> bh (h w) d')
+ local_output = self.attn1(
+ local_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ local_output = rearrange(local_output, 'bh (h w) d -> bh h w d', h = int(sub_h / scale_factor))
+
+ value[:, h_start:h_end, w_start:w_end, :] += local_output * 1
+ count[:, h_start:h_end, w_start:w_end, :] += 1
+
+ value = value[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
+ count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
+ attn_output = torch.where(count>0, value/count, value)
+
+ gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
+
+ attn_output_global = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h)
+
+ gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
+
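+ # frequency fusion: keep the low-frequency content of the local windowed attention and add the high-frequency residual of the global attention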
+ attn_output = gaussian_local + (attn_output_global - gaussian_global)
+ attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
+
+ elif fourg_window:
+ norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
+ norm_hidden_states_ = F.pad(norm_hidden_states, (0, 0, w_jitter_range, w_jitter_range, h_jitter_range, h_jitter_range), 'constant', 0)
+ value = torch.zeros_like(norm_hidden_states_)
+ count = torch.zeros_like(norm_hidden_states_)
+ for index, view in enumerate(views):
+ h_start, h_end, w_start, w_end = view
+ local_states = norm_hidden_states_[:, h_start:h_end, w_start:w_end, :]
+ local_states = rearrange(local_states, 'bh h w d -> bh (h w) d')
+ local_output = self.attn1(
+ local_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ local_output = rearrange(local_output, 'bh (h w) d -> bh h w d', h = int(sub_h / scale_factor))
+
+ value[:, h_start:h_end, w_start:w_end, :] += local_output * 1
+ count[:, h_start:h_end, w_start:w_end, :] += 1
+
+ value = value[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
+ count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
+ attn_output = torch.where(count>0, value/count, value)
+
+ gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
+
+ value = torch.zeros_like(norm_hidden_states)
+ count = torch.zeros_like(norm_hidden_states)
+ for index, global_view in enumerate(global_views):
+ h, w = global_view
+ global_states = norm_hidden_states[:, h::current_scale_num_h, w::current_scale_num_w, :]
+ global_states = rearrange(global_states, 'bh h w d -> bh (h w) d')
+ global_output = self.attn1(
+ global_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ global_output = rearrange(global_output, 'bh (h w) d -> bh h w d', h = int(global_output.shape[1] ** 0.5))
+
+ value[:, h::current_scale_num_h, w::current_scale_num_w, :] += global_output * 1
+ count[:, h::current_scale_num_h, w::current_scale_num_w, :] += 1
+
+ attn_output_global = torch.where(count>0, value/count, value)
+
+ gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
+
+ attn_output = gaussian_local + (attn_output_global - gaussian_global)
+ attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
+
+ else:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ # 2.5 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+ # 2.5 ends
+
+ # 3. Cross-Attention
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+ ff_output = torch.cat(
+ [
+ self.ff(hid_slice)
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
+ ],
+ dim=self._chunk_dim,
+ )
+ else:
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
+def ori_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+):
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ # 2. Prepare GLIGEN inputs
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ # 2.5 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+ # 2.5 ends
+
+ # 3. Cross-Attention
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+ ff_output = torch.cat(
+ [
+ self.ff(hid_slice)
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
+ ],
+ dim=self._chunk_dim,
+ )
+ else:
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
diff --git a/modules/ggml/__init__.py b/modules/ggml/__init__.py
index acac6057c..44721d846 100644
--- a/modules/ggml/__init__.py
+++ b/modules/ggml/__init__.py
@@ -1,11 +1,35 @@
-from pathlib import Path
+import os
+import time
import torch
-import gguf
-from .gguf_utils import TORCH_COMPATIBLE_QTYPES
-from .gguf_tensor import GGMLTensor
+import diffusers
+import transformers
-def load_gguf_state_dict(path: str, compute_dtype: torch.dtype) -> dict[str, GGMLTensor]:
+def install_gguf():
+ # pip install git+https://github.com/junejae/transformers@feature/t5-gguf
+ # https://github.com/ggerganov/llama.cpp/issues/9566
+ from installer import install
+ install('gguf', quiet=True)
+ import importlib
+ import gguf
+ from modules import shared
+ scripts_dir = os.path.join(os.path.dirname(gguf.__file__), '..', 'scripts')
+ if os.path.exists(scripts_dir):
+ os.rename(scripts_dir, scripts_dir + str(time.time()))
+ # monkey patch transformers/diffusers so they detect the newly installed gguf package correctly
+ ver = importlib.metadata.version('gguf')
+ transformers.utils.import_utils._is_gguf_available = True # pylint: disable=protected-access
+ transformers.utils.import_utils._gguf_version = ver # pylint: disable=protected-access
+ diffusers.utils.import_utils._is_gguf_available = True # pylint: disable=protected-access
+ diffusers.utils.import_utils._gguf_version = ver # pylint: disable=protected-access
+ shared.log.debug(f'Load GGUF: version={ver}')
+ return gguf
+
+
+def load_gguf_state_dict(path: str, compute_dtype: torch.dtype) -> dict:
+ gguf = install_gguf()
+ from .gguf_utils import TORCH_COMPATIBLE_QTYPES
+ from .gguf_tensor import GGMLTensor
sd: dict[str, GGMLTensor] = {}
stats = {}
reader = gguf.GGUFReader(path)
@@ -19,3 +43,14 @@ def load_gguf_state_dict(path: str, compute_dtype: torch.dtype) -> dict[str, GGM
stats[tensor.tensor_type.name] = 0
stats[tensor.tensor_type.name] += 1
return sd, stats
+
+
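+# load a model class directly from a .gguf file via diffusers' GGUF quantization support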
+def load_gguf(path, cls, compute_dtype: torch.dtype):
+ _gguf = install_gguf()
+ module = cls.from_single_file(
+ path,
+ quantization_config = diffusers.GGUFQuantizationConfig(compute_dtype=compute_dtype),
+ torch_dtype=compute_dtype,
+ )
+ module.gguf = 'gguf'
+ return module
diff --git a/modules/ggml/gguf_tensor.py b/modules/ggml/gguf_tensor.py
index 4bc9117cb..8b2f608ac 100644
--- a/modules/ggml/gguf_tensor.py
+++ b/modules/ggml/gguf_tensor.py
@@ -131,7 +131,6 @@ def get_dequantized_tensor(self):
if self._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES:
return self.quantized_data.to(self.compute_dtype)
elif self._ggml_quantization_type in DEQUANTIZE_FUNCTIONS:
- # TODO(ryand): Look into how the dtype param is intended to be used.
return dequantize(
data=self.quantized_data, qtype=self._ggml_quantization_type, oshape=self.tensor_shape, dtype=None
).to(self.compute_dtype)
diff --git a/modules/hashes.py b/modules/hashes.py
index cf83794b0..a003f4840 100644
--- a/modules/hashes.py
+++ b/modules/hashes.py
@@ -9,6 +9,13 @@
cache_data = None
progress_ok = True
+
+def init_cache():
+ global cache_data # pylint: disable=global-statement
+ if cache_data is None:
+ cache_data = {} if not os.path.isfile(cache_filename) else shared.readfile(cache_filename, lock=True)
+
+
def dump_cache():
shared.writefile(cache_data, cache_filename)
diff --git a/modules/hidiffusion/hidiffusion.py b/modules/hidiffusion/hidiffusion.py
index 7874f03af..d6f68eb15 100644
--- a/modules/hidiffusion/hidiffusion.py
+++ b/modules/hidiffusion/hidiffusion.py
@@ -234,7 +234,7 @@ def window_reverse(windows, window_size, H, W, shift_size):
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
if self._chunk_size is not None:
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) # pylint: disable=undefined-variable # TODO hidiffusion undefined
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) # pylint: disable=undefined-variable
else:
ff_output = self.ff(norm_hidden_states)
if self.use_ada_layer_norm_zero:
@@ -308,7 +308,7 @@ def forward(
self.T1 = int(self.max_timestep * self.T1_ratio)
output_states = ()
- _scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 # TODO hidiffusion unused
+ _scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
blocks = list(zip(self.resnets, self.attentions))
@@ -407,7 +407,7 @@ def forward(
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
- def fix_scale(first, second): # TODO hidiffusion breaks hidden_scale.shape on 3rd generate with sdxl
+ def fix_scale(first, second):
if (first.shape[-1] != second.shape[-1] or first.shape[-2] != second.shape[-2]):
rescale = min(second.shape[-2] / first.shape[-2], second.shape[-1] / first.shape[-1])
# log.debug(f"HiDiffusion rescale: {hidden_states.shape} => {res_hidden_states_tuple[0].shape} scale={rescale}")
diff --git a/modules/images.py b/modules/images.py
index 910349bef..2cfbe941d 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -267,6 +267,8 @@ def safe_decode_string(s: bytes):
def read_info_from_image(image: Image, watermark: bool = False):
+ if image is None:
+ return '', {}
items = image.info or {}
geninfo = items.pop('parameters', None) or items.pop('UserComment', None)
if geninfo is not None and len(geninfo) > 0:
diff --git a/modules/images_namegen.py b/modules/images_namegen.py
index bc58f728a..7cad39b92 100644
--- a/modules/images_namegen.py
+++ b/modules/images_namegen.py
@@ -45,12 +45,12 @@ class FilenameGenerator:
'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
'sampler': lambda self: self.p and self.p.sampler_name,
- 'seed': lambda self: self.seed and str(self.seed) or '',
+ 'seed': lambda self: (self.seed and str(self.seed)) or '',
'steps': lambda self: self.p and getattr(self.p, 'steps', 0),
'cfg': lambda self: self.p and getattr(self.p, 'cfg_scale', 0),
'clip_skip': lambda self: self.p and getattr(self.p, 'clip_skip', 0),
'denoising': lambda self: self.p and getattr(self.p, 'denoising_strength', 0),
- 'styles': lambda self: self.p and ", ".join([style for style in self.p.styles if not style == "None"]) or "None",
+ 'styles': lambda self: (self.p and ", ".join([style for style in self.p.styles if not style == "None"])) or "None",
'uuid': lambda self: str(uuid.uuid4()),
}
default_time_format = '%Y%m%d%H%M%S'
diff --git a/modules/images_resize.py b/modules/images_resize.py
index d86ff6f22..362be79ee 100644
--- a/modules/images_resize.py
+++ b/modules/images_resize.py
@@ -5,13 +5,13 @@
from modules import shared
-def resize_image(resize_mode, im, width, height, upscaler_name=None, output_type='image', context=None):
+def resize_image(resize_mode: int, im: Image.Image, width: int, height: int, upscaler_name: str=None, output_type: str='image', context: str=None):
upscaler_name = upscaler_name or shared.opts.upscaler_for_img2img
def latent(im, w, h, upscaler):
from modules.processing_vae import vae_encode, vae_decode
import torch
- latents = vae_encode(im, shared.sd_model, full_quality=False) # TODO enable full VAE mode for resize-latent
+ latents = vae_encode(im, shared.sd_model, full_quality=False) # TODO resize image: enable full VAE mode for resize-latent
latents = torch.nn.functional.interpolate(latents, size=(int(h // 8), int(w // 8)), mode=upscaler["mode"], antialias=upscaler["antialias"])
im = vae_decode(latents, shared.sd_model, output_type='pil', full_quality=False)[0]
return im
@@ -79,18 +79,18 @@ def fill(im, color=None):
def context_aware(im, width, height, context):
import seam_carving # https://github.com/li-plus/seam-carving
- if 'forward' in context:
+ if 'forward' in context.lower():
energy_mode = "forward"
- elif 'backward' in context:
+ elif 'backward' in context.lower():
energy_mode = "backward"
else:
return im
- if 'Add' in context:
+ if 'add' in context.lower():
src_ratio = min(width / im.width, height / im.height)
src_w = int(im.width * src_ratio)
src_h = int(im.height * src_ratio)
src_image = resize(im, src_w, src_h)
- elif 'Remove' in context:
+ elif 'remove' in context.lower():
ratio = width / height
src_ratio = im.width / im.height
src_w = width if ratio > src_ratio else im.width * height // im.height
@@ -122,7 +122,7 @@ def context_aware(im, width, height, context):
from modules import masking
res = fill(im, color=0)
res, _mask = masking.outpaint(res)
- elif resize_mode == 5: # context-aware
+ elif resize_mode == 5: # context-aware
res = context_aware(im, width, height, context)
else:
res = im.copy()
diff --git a/modules/img2img.py b/modules/img2img.py
index 8274386cc..2e3eca54d 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -1,6 +1,7 @@
import os
import itertools # SBM Batch frames
import numpy as np
+import filetype
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
import modules.scripts
from modules import shared, processing, images
@@ -8,7 +9,6 @@
from modules.ui import plaintext_to_html
from modules.memstats import memory_stats
-
debug = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None
debug('Trace: PROCESS')
@@ -16,24 +16,25 @@
def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args):
shared.log.debug(f'batch: {input_files}|{input_dir}|{output_dir}|{inpaint_mask_dir}')
processing.fix_seed(p)
+ image_files = []
if input_files is not None and len(input_files) > 0:
image_files = [f.name for f in input_files]
- else:
- if not os.path.isdir(input_dir):
- shared.log.error(f"Process batch: directory not found: {input_dir}")
- return
- image_files = os.listdir(input_dir)
- image_files = [os.path.join(input_dir, f) for f in image_files]
+ image_files = [f for f in image_files if filetype.is_image(f)]
+ shared.log.info(f'Process batch: input images={len(image_files)}')
+ elif os.path.isdir(input_dir):
+ image_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir)]
+ image_files = [f for f in image_files if filetype.is_image(f)]
+ shared.log.info(f'Process batch: input folder="{input_dir}" images={len(image_files)}')
is_inpaint_batch = False
- if inpaint_mask_dir:
- inpaint_masks = os.listdir(inpaint_mask_dir)
- inpaint_masks = [os.path.join(inpaint_mask_dir, f) for f in inpaint_masks]
+ if inpaint_mask_dir and os.path.isdir(inpaint_mask_dir):
+ inpaint_masks = [os.path.join(inpaint_mask_dir, f) for f in os.listdir(inpaint_mask_dir)]
+ inpaint_masks = [f for f in inpaint_masks if filetype.is_image(f)]
is_inpaint_batch = len(inpaint_masks) > 0
- if is_inpaint_batch:
- shared.log.info(f"Process batch: inpaint batch masks={len(inpaint_masks)}")
+ shared.log.info(f'Process batch: mask folder="{inpaint_mask_dir}" images={len(inpaint_masks)}')
save_normally = output_dir == ''
p.do_not_save_grid = True
p.do_not_save_samples = not save_normally
+ p.default_prompt = p.prompt
shared.state.job_count = len(image_files) * p.n_iter
if shared.opts.batch_frame_mode: # SBM Frame mode is on, process each image in batch with same seed
window_size = p.batch_size
@@ -55,14 +56,29 @@ def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args)
for image_file in batch_image_files:
try:
img = Image.open(image_file)
- if p.scale_by != 1:
- p.width = int(img.width * p.scale_by)
- p.height = int(img.height * p.scale_by)
+ img = ImageOps.exif_transpose(img)
+ batch_images.append(img)
+ # p.init()
+ p.width = int(img.width * p.scale_by)
+ p.height = int(img.height * p.scale_by)
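+ # per-image prompt: use a matching .txt caption file if it exists, otherwise fall back to the default prompt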
+ caption_file = os.path.splitext(image_file)[0] + '.txt'
+ prompt_type='default'
+ if os.path.exists(caption_file):
+ with open(caption_file, 'r', encoding='utf8') as f:
+ p.prompt = f.read()
+ prompt_type='file'
+ else:
+ p.prompt = p.default_prompt
+ p.all_prompts = None
+ p.all_negative_prompts = None
+ p.all_seeds = None
+ p.all_subseeds = None
+ shared.log.debug(f'Process batch: image="{image_file}" prompt={prompt_type} i={i+1}/{len(image_files)}')
except UnidentifiedImageError as e:
- shared.log.error(f"Image error: {e}")
- continue
- img = ImageOps.exif_transpose(img)
- batch_images.append(img)
+ shared.log.error(f'Process batch: image="{image_file}" {e}')
+ if len(batch_images) == 0:
+ shared.log.warning("Process batch: no images found in batch")
+ continue
batch_images = batch_images * btcrept # Standard mode sends the same image per batchsize.
p.init_images = batch_images
@@ -81,17 +97,20 @@ def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args)
batch_image_files = batch_image_files * btcrept # List used for naming later.
- proc = modules.scripts.scripts_img2img.run(p, *args)
- if proc is None:
- proc = processing.process_images(p)
- for n, (image, image_file) in enumerate(itertools.zip_longest(proc.images,batch_image_files)):
+ processed = modules.scripts.scripts_img2img.run(p, *args)
+ if processed is None:
+ processed = processing.process_images(p)
+
+ for n, (image, image_file) in enumerate(itertools.zip_longest(processed.images, batch_image_files)):
+ if image is None:
+ continue
basename = ''
if shared.opts.use_original_name_batch:
forced_filename, ext = os.path.splitext(os.path.basename(image_file))
else:
forced_filename = None
ext = shared.opts.samples_format
- if len(proc.images) > 1:
+ if len(processed.images) > 1:
basename = f'{n + i}' if shared.opts.batch_frame_mode else f'{n}'
else:
basename = ''
@@ -103,7 +122,7 @@ def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args)
for k, v in items.items():
image.info[k] = v
images.save_image(image, path=output_dir, basename=basename, seed=None, prompt=None, extension=ext, info=geninfo, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=image.info, forced_filename=forced_filename)
- proc = modules.scripts.scripts_img2img.after(p, proc, *args)
+ processed = modules.scripts.scripts_img2img.after(p, processed, *args)
shared.log.debug(f'Processed: images={len(batch_image_files)} memory={memory_stats()} batch')
@@ -147,29 +166,20 @@ def img2img(id_task: str, state: str, mode: int,
debug(f'img2img: id_task={id_task}|mode={mode}|prompt={prompt}|negative_prompt={negative_prompt}|prompt_styles={prompt_styles}|init_img={init_img}|sketch={sketch}|init_img_with_mask={init_img_with_mask}|inpaint_color_sketch={inpaint_color_sketch}|inpaint_color_sketch_orig={inpaint_color_sketch_orig}|init_img_inpaint={init_img_inpaint}|init_mask_inpaint={init_mask_inpaint}|steps={steps}|sampler_index={sampler_index}||mask_blur={mask_blur}|mask_alpha={mask_alpha}|inpainting_fill={inpainting_fill}|full_quality={full_quality}|detailer={detailer}|tiling={tiling}|hidiffusion={hidiffusion}|n_iter={n_iter}|batch_size={batch_size}|cfg_scale={cfg_scale}|image_cfg_scale={image_cfg_scale}|clip_skip={clip_skip}|denoising_strength={denoising_strength}|seed={seed}|subseed{subseed}|subseed_strength={subseed_strength}|seed_resize_from_h={seed_resize_from_h}|seed_resize_from_w={seed_resize_from_w}|selected_scale_tab={selected_scale_tab}|height={height}|width={width}|scale_by={scale_by}|resize_mode={resize_mode}|resize_name={resize_name}|resize_context={resize_context}|inpaint_full_res={inpaint_full_res}|inpaint_full_res_padding={inpaint_full_res_padding}|inpainting_mask_invert={inpainting_mask_invert}|img2img_batch_files={img2img_batch_files}|img2img_batch_input_dir={img2img_batch_input_dir}|img2img_batch_output_dir={img2img_batch_output_dir}|img2img_batch_inpaint_mask_dir={img2img_batch_inpaint_mask_dir}|override_settings_texts={override_settings_texts}')
- if mode == 5:
- if img2img_batch_files is None or len(img2img_batch_files) == 0:
- shared.log.debug('Init bactch images not set')
- elif init_img:
- shared.log.debug('Init image not set')
-
if sampler_index is None:
shared.log.warning('Sampler: invalid')
sampler_index = 0
+ mode = int(mode)
+ image = None
+ mask = None
override_settings = create_override_settings_dict(override_settings_texts)
- if mode == 0: # img2img
+ if mode == 0: # img2img
if init_img is None:
return [], '', '', 'Error: init image not provided'
image = init_img.convert("RGB")
- mask = None
- elif mode == 1: # img2img sketch
- if sketch is None:
- return [], '', '', 'Error: sketch image not provided'
- image = sketch.convert("RGB")
- mask = None
- elif mode == 2: # inpaint
+ elif mode == 1: # inpaint
if init_img_with_mask is None:
return [], '', '', 'Error: init image with mask not provided'
image = init_img_with_mask["image"]
@@ -177,7 +187,11 @@ def img2img(id_task: str, state: str, mode: int,
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
image = image.convert("RGB")
- elif mode == 3: # inpaint sketch
+ elif mode == 2: # sketch
+ if sketch is None:
+ return [], '', '', 'Error: sketch image not provided'
+ image = sketch.convert("RGB")
+ elif mode == 3: # composite
if inpaint_color_sketch is None:
return [], '', '', 'Error: color sketch image not provided'
image = inpaint_color_sketch
@@ -188,15 +202,16 @@ def img2img(id_task: str, state: str, mode: int,
blur = ImageFilter.GaussianBlur(mask_blur)
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
image = image.convert("RGB")
- elif mode == 4: # inpaint upload mask
+ elif mode == 4: # inpaint upload mask
if init_img_inpaint is None:
return [], '', '', 'Error: inpaint image not provided'
image = init_img_inpaint
mask = init_mask_inpaint
+ elif mode == 5: # process batch
+ pass # handled later
else:
shared.log.error(f'Image processing unknown mode: {mode}')
- image = None
- mask = None
+
if image is not None:
image = ImageOps.exif_transpose(image)
if selected_scale_tab == 1 and resize_mode != 0:
diff --git a/modules/infotext.py b/modules/infotext.py
index 05e06e600..baa995c88 100644
--- a/modules/infotext.py
+++ b/modules/infotext.py
@@ -10,6 +10,7 @@
debug = lambda *args, **kwargs: None # pylint: disable=unnecessary-lambda-assignment
re_size = re.compile(r"^(\d+)x(\d+)$") # int x int
re_param = re.compile(r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)') # multi-word: value
+re_lora = re.compile("= 2.3:
torch.cuda._initialization_lock = torch.xpu._initialization_lock
torch.cuda._initialized = torch.xpu._initialized
torch.cuda._is_in_bad_fork = torch.xpu._is_in_bad_fork
@@ -122,7 +122,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.traceback = torch.xpu.traceback
# Memory:
- if legacy and 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
+ if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
torch.cuda.empty_cache = torch.xpu.empty_cache
@@ -159,7 +159,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd
torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd
torch.cuda.amp = torch.xpu.amp
- if not ipex.__version__.startswith("2.3"):
+ if float(ipex.__version__[:3]) < 2.3:
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
@@ -178,7 +178,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
# C
- if legacy and not ipex.__version__.startswith("2.3"):
+ if legacy and float(ipex.__version__[:3]) < 2.3:
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 12
diff --git a/modules/intel/ipex/diffusers.py b/modules/intel/ipex/diffusers.py
index f742fe5c0..5bf5bbe39 100644
--- a/modules/intel/ipex/diffusers.py
+++ b/modules/intel/ipex/diffusers.py
@@ -1,7 +1,7 @@
import os
from functools import wraps, cache
import torch
-import diffusers #0.29.1 # pylint: disable=import-error
+import diffusers # pylint: disable=import-error
from diffusers.models.attention_processor import Attention
# pylint: disable=protected-access, missing-function-docstring, line-too-long
@@ -20,20 +20,31 @@ def fourier_filter(x_in, threshold, scale):
# fp64 error
-def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
- assert dim % 2 == 0, "The dimension must be even."
-
- scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim # force fp32 instead of fp64
- omega = 1.0 / (theta**scale)
-
- batch_size, seq_length = pos.shape
- out = torch.einsum("...n,d->...nd", pos, omega)
- cos_out = torch.cos(out)
- sin_out = torch.sin(out)
-
- stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
- out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
- return out.float()
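+# float32 replacement for diffusers.models.embeddings.FluxPosEmbed, used when the device lacks fp64 support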
+class FluxPosEmbed(torch.nn.Module):
+ def __init__(self, theta: int, axes_dim):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ cos_out = []
+ sin_out = []
+ pos = ids.float()
+ for i in range(n_axes):
+ cos, sin = diffusers.models.embeddings.get_1d_rotary_pos_embed(
+ self.axes_dim[i],
+ pos[:, i],
+ theta=self.theta,
+ repeat_interleave_real=True,
+ use_real=True,
+ freqs_dtype=torch.float32,
+ )
+ cos_out.append(cos)
+ sin_out.append(sin)
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+ return freqs_cos, freqs_sin
@cache
@@ -337,4 +348,5 @@ def ipex_diffusers():
if not device_supports_fp64 or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
diffusers.models.attention_processor.AttnProcessor = AttnProcessor
- diffusers.models.transformers.transformer_flux.rope = rope
+ if not device_supports_fp64:
+ diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed
diff --git a/modules/intel/ipex/hijacks.py b/modules/intel/ipex/hijacks.py
index 7ec94138d..b1c9a1182 100644
--- a/modules/intel/ipex/hijacks.py
+++ b/modules/intel/ipex/hijacks.py
@@ -149,6 +149,15 @@ def functional_linear(input, weight, bias=None):
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_linear(input, weight, bias=bias)
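+# conv1d variant of the dtype-cast hijack: align input/bias dtype with the weight dtype, mirroring the existing linear and conv2d hijacks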
+original_functional_conv1d = torch.nn.functional.conv1d
+@wraps(torch.nn.functional.conv1d)
+def functional_conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ if input.dtype != weight.data.dtype:
+ input = input.to(dtype=weight.data.dtype)
+ if bias is not None and bias.data.dtype != weight.data.dtype:
+ bias.data = bias.data.to(dtype=weight.data.dtype)
+ return original_functional_conv1d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+
original_functional_conv2d = torch.nn.functional.conv2d
@wraps(torch.nn.functional.conv2d)
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
@@ -158,6 +167,16 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+# LTX Video
+original_functional_conv3d = torch.nn.functional.conv3d
+@wraps(torch.nn.functional.conv3d)
+def functional_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ if input.dtype != weight.data.dtype:
+ input = input.to(dtype=weight.data.dtype)
+ if bias is not None and bias.data.dtype != weight.data.dtype:
+ bias.data = bias.data.to(dtype=weight.data.dtype)
+ return original_functional_conv3d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+
# SwinIR BF16:
original_functional_pad = torch.nn.functional.pad
@wraps(torch.nn.functional.pad)
@@ -294,7 +313,7 @@ def torch_load(f, map_location=None, *args, **kwargs):
# Hijack Functions:
def ipex_hijacks(legacy=True):
- if legacy:
+ if legacy and float(torch.__version__[:3]) < 2.5:
torch.nn.functional.interpolate = interpolate
torch.tensor = torch_tensor
torch.Tensor.to = Tensor_to
@@ -320,7 +339,9 @@ def ipex_hijacks(legacy=True):
torch.nn.functional.group_norm = functional_group_norm
torch.nn.functional.layer_norm = functional_layer_norm
torch.nn.functional.linear = functional_linear
+ torch.nn.functional.conv1d = functional_conv1d
torch.nn.functional.conv2d = functional_conv2d
+ torch.nn.functional.conv3d = functional_conv3d
torch.nn.functional.pad = functional_pad
torch.bmm = torch_bmm
diff --git a/modules/ipadapter.py b/modules/ipadapter.py
index d5bfbec8c..aa010c33d 100644
--- a/modules/ipadapter.py
+++ b/modules/ipadapter.py
@@ -9,11 +9,16 @@
import time
import json
from PIL import Image
-from modules import processing, shared, devices, sd_models
+import diffusers
+import transformers
+from modules import processing, shared, devices, sd_models, errors
-clip_repo = "h94/IP-Adapter"
clip_loaded = None
+adapters_loaded = []
+CLIP_ID = "h94/IP-Adapter"
+OPEN_ID = "openai/clip-vit-large-patch14"
+SIGLIP_ID = 'google/siglip-so400m-patch14-384'
ADAPTERS_NONE = {
'None': { 'name': 'none', 'repo': 'none', 'subfolder': 'none' },
}
@@ -36,11 +41,13 @@
'Ostris Composition ViT-H SDXL': { 'name': 'ip_plus_composition_sdxl.safetensors', 'repo': 'ostris/ip-composition-adapter', 'subfolder': '' },
}
ADAPTERS_SD3 = {
- 'InstantX Large': { 'name': 'ip-adapter.bin', 'repo': 'InstantX/SD3.5-Large-IP-Adapter' },
+ 'None': { 'name': 'none', 'repo': 'none', 'subfolder': 'none' },
+ 'InstantX Large': { 'name': 'none', 'repo': 'InstantX/SD3.5-Large-IP-Adapter', 'subfolder': 'none', 'revision': 'refs/pr/10' },
}
ADAPTERS_F1 = {
- 'XLabs AI v1': { 'name': 'ip_adapter.safetensors', 'repo': 'XLabs-AI/flux-ip-adapter' },
- 'XLabs AI v2': { 'name': 'ip_adapter.safetensors', 'repo': 'XLabs-AI/flux-ip-adapter-v2' },
+ 'None': { 'name': 'none', 'repo': 'none', 'subfolder': 'none' },
+ 'XLabs AI v1': { 'name': 'ip_adapter.safetensors', 'repo': 'XLabs-AI/flux-ip-adapter', 'subfolder': 'none' },
+ 'XLabs AI v2': { 'name': 'ip_adapter.safetensors', 'repo': 'XLabs-AI/flux-ip-adapter-v2', 'subfolder': 'none' },
}
ADAPTERS = { **ADAPTERS_SD15, **ADAPTERS_SDXL, **ADAPTERS_SD3, **ADAPTERS_F1 }
ADAPTERS_ALL = { **ADAPTERS_SD15, **ADAPTERS_SDXL, **ADAPTERS_SD3, **ADAPTERS_F1 }
@@ -125,14 +132,27 @@ def crop_images(images, crops):
shared.log.error(f'IP adapter: failed to crop image: source={len(images[i])} faces={len(cropped)}')
except Exception as e:
shared.log.error(f'IP adapter: failed to crop image: {e}')
+ if shared.sd_model_type == 'sd3' and len(images) == 1:
+ return images[0]
return images
-def unapply(pipe): # pylint: disable=arguments-differ
+def unapply(pipe, unload: bool = False): # pylint: disable=arguments-differ
+ if len(adapters_loaded) == 0:
+ return
try:
if hasattr(pipe, 'set_ip_adapter_scale'):
pipe.set_ip_adapter_scale(0)
- if hasattr(pipe, 'unet') and hasattr(pipe.unet, 'config') and pipe.unet.config.encoder_hid_dim_type == 'ip_image_proj':
+ if unload:
+ shared.log.debug('IP adapter unload')
+ pipe.unload_ip_adapter()
+ if hasattr(pipe, 'unet'):
+ module = pipe.unet
+ elif hasattr(pipe, 'transformer'):
+ module = pipe.transformer
+ else:
+ module = None
+ if module is not None and hasattr(module, 'config') and module.config.encoder_hid_dim_type == 'ip_image_proj':
pipe.unet.encoder_hid_proj = None
pipe.config.encoder_hid_dim_type = None
pipe.unet.set_default_attn_processor()
@@ -140,27 +160,78 @@ def unapply(pipe): # pylint: disable=arguments-differ
pass
-def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapter_scales=[1.0], adapter_crops=[False], adapter_starts=[0.0], adapter_ends=[1.0], adapter_images=[]):
+def load_image_encoder(pipe: diffusers.DiffusionPipeline, adapter_names: list[str]):
global clip_loaded # pylint: disable=global-statement
- # overrides
- if hasattr(p, 'ip_adapter_names'):
- if isinstance(p.ip_adapter_names, str):
- p.ip_adapter_names = [p.ip_adapter_names]
- adapters = [ADAPTERS_ALL.get(adapter_name, None) for adapter_name in p.ip_adapter_names if adapter_name is not None and adapter_name.lower() != 'none']
- adapter_names = p.ip_adapter_names
- else:
- if isinstance(adapter_names, str):
- adapter_names = [adapter_names]
- adapters = [ADAPTERS.get(adapter, None) for adapter in adapter_names]
- adapters = [adapter for adapter in adapters if adapter is not None and adapter['name'].lower() != 'none']
- if len(adapters) == 0:
- unapply(pipe)
- if hasattr(p, 'ip_adapter_images'):
- del p.ip_adapter_images
- return False
- if shared.sd_model_type not in ['sd', 'sdxl', 'sd3', 'f1']:
- shared.log.error(f'IP adapter: model={shared.sd_model_type} class={pipe.__class__.__name__} not supported')
- return False
+ for adapter_name in adapter_names:
+ # which clip to use
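+ # SD/SDXL adapters use the CLIP encoders bundled in h94/IP-Adapter; SD3 switches to SigLIP and FLUX.1 to OpenAI CLIP ViT-L below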
+ clip_repo = CLIP_ID
+ if 'ViT' not in adapter_name: # defaults per model
+ clip_subfolder = 'models/image_encoder' if shared.sd_model_type == 'sd' else 'sdxl_models/image_encoder'
+ if 'ViT-H' in adapter_name:
+ clip_subfolder = 'models/image_encoder' # this is vit-h
+ elif 'ViT-G' in adapter_name:
+ clip_subfolder = 'sdxl_models/image_encoder' # this is vit-g
+ else:
+ if shared.sd_model_type == 'sd':
+ clip_subfolder = 'models/image_encoder'
+ elif shared.sd_model_type == 'sdxl':
+ clip_subfolder = 'sdxl_models/image_encoder'
+ elif shared.sd_model_type == 'sd3':
+ clip_repo = SIGLIP_ID
+ clip_subfolder = None
+ elif shared.sd_model_type == 'f1':
+ clip_repo = OPEN_ID
+ clip_subfolder = None
+ else:
+ shared.log.error(f'IP adapter: unknown model type: {adapter_name}')
+ return False
+
+ # load image encoder used by ip adapter
+ if pipe.image_encoder is None or clip_loaded != f'{clip_repo}/{clip_subfolder}':
+ try:
+ if shared.sd_model_type == 'sd3':
+ image_encoder = transformers.SiglipVisionModel.from_pretrained(clip_repo, torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir)
+ else:
+ if clip_subfolder is None:
+ image_encoder = transformers.CLIPVisionModelWithProjection.from_pretrained(clip_repo, torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir, use_safetensors=True)
+ shared.log.debug(f'IP adapter load: encoder="{clip_repo}" cls={pipe.image_encoder.__class__.__name__}')
+ else:
+ image_encoder = transformers.CLIPVisionModelWithProjection.from_pretrained(clip_repo, subfolder=clip_subfolder, torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir, use_safetensors=True)
+ shared.log.debug(f'IP adapter load: encoder="{clip_repo}/{clip_subfolder}" cls={pipe.image_encoder.__class__.__name__}')
+ if hasattr(pipe, 'register_modules'):
+ pipe.register_modules(image_encoder=image_encoder)
+ else:
+ pipe.image_encoder = image_encoder
+ clip_loaded = f'{clip_repo}/{clip_subfolder}'
+ except Exception as e:
+ shared.log.error(f'IP adapter load: encoder="{clip_repo}/{clip_subfolder}" {e}')
+ errors.display(e, 'IP adapter: type=encoder')
+ return False
+ sd_models.move_model(pipe.image_encoder, devices.device)
+ return True
+
+
+def load_feature_extractor(pipe):
+ # load feature extractor used by ip adapter
+ if pipe.feature_extractor is None:
+ try:
+ if shared.sd_model_type == 'sd3':
+ feature_extractor = transformers.SiglipImageProcessor.from_pretrained(SIGLIP_ID, torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir)
+ else:
+ feature_extractor = transformers.CLIPImageProcessor()
+ if hasattr(pipe, 'register_modules'):
+ pipe.register_modules(feature_extractor=feature_extractor)
+ else:
+ pipe.feature_extractor = feature_extractor
+ shared.log.debug(f'IP adapter load: extractor={pipe.feature_extractor.__class__.__name__}')
+ except Exception as e:
+ shared.log.error(f'IP adapter load: extractor {e}')
+ errors.display(e, 'IP adapter: type=extractor')
+ return False
+ return True
+
+
+def parse_params(p: processing.StableDiffusionProcessing, adapters: list, adapter_scales: list[float], adapter_crops: list[bool], adapter_starts: list[float], adapter_ends: list[float], adapter_images: list):
if hasattr(p, 'ip_adapter_scales'):
adapter_scales = p.ip_adapter_scales
if hasattr(p, 'ip_adapter_crops'):
@@ -201,6 +272,33 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt
p.ip_adapter_starts = adapter_starts.copy()
adapter_ends = get_scales(adapter_ends, adapter_images)
p.ip_adapter_ends = adapter_ends.copy()
+ return adapter_images, adapter_masks, adapter_scales, adapter_crops, adapter_starts, adapter_ends
+
+
+def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapter_scales=[1.0], adapter_crops=[False], adapter_starts=[0.0], adapter_ends=[1.0], adapter_images=[]):
+ global adapters_loaded # pylint: disable=global-statement
+ # overrides
+ if hasattr(p, 'ip_adapter_names'):
+ if isinstance(p.ip_adapter_names, str):
+ p.ip_adapter_names = [p.ip_adapter_names]
+ adapters = [ADAPTERS_ALL.get(adapter_name, None) for adapter_name in p.ip_adapter_names if adapter_name is not None and adapter_name.lower() != 'none']
+ adapter_names = p.ip_adapter_names
+ else:
+ if isinstance(adapter_names, str):
+ adapter_names = [adapter_names]
+ adapters = [ADAPTERS.get(adapter_name, None) for adapter_name in adapter_names if adapter_name.lower() != 'none']
+
+ if len(adapters) == 0:
+ unapply(pipe, getattr(p, 'ip_adapter_unload', False))
+ if hasattr(p, 'ip_adapter_images'):
+ del p.ip_adapter_images
+ return False
+ if shared.sd_model_type not in ['sd', 'sdxl', 'sd3', 'f1']:
+ shared.log.error(f'IP adapter: model={shared.sd_model_type} class={pipe.__class__.__name__} not supported')
+ return False
+
+ adapter_images, adapter_masks, adapter_scales, adapter_crops, adapter_starts, adapter_ends = parse_params(p, adapters, adapter_scales, adapter_crops, adapter_starts, adapter_ends, adapter_images)
+
# init code
if pipe is None:
return False
@@ -211,7 +309,7 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt
shared.log.error('IP adapter: no image provided')
adapters = [] # unload adapter if previously loaded as it will cause runtime errors
if len(adapters) == 0:
- unapply(pipe)
+ unapply(pipe, getattr(p, 'ip_adapter_unload', False))
if hasattr(p, 'ip_adapter_images'):
del p.ip_adapter_images
return False
@@ -219,61 +317,30 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt
shared.log.error(f'IP adapter: pipeline not supported: {pipe.__class__.__name__}')
return False
- for adapter_name in adapter_names:
- # which clip to use
- if 'ViT' not in adapter_name: # defaults per model
- if shared.sd_model_type == 'sd':
- clip_subfolder = 'models/image_encoder'
- else:
- clip_subfolder = 'sdxl_models/image_encoder'
- if 'ViT-H' in adapter_name:
- clip_subfolder = 'models/image_encoder' # this is vit-h
- elif 'ViT-G' in adapter_name:
- clip_subfolder = 'sdxl_models/image_encoder' # this is vit-g
- else:
- if shared.sd_model_type == 'sd':
- clip_subfolder = 'models/image_encoder'
- elif shared.sd_model_type == 'sdxl':
- clip_subfolder = 'sdxl_models/image_encoder'
- elif shared.sd_model_type == 'sd3':
- shared.log.error(f'IP adapter: adapter={adapter_name} type={shared.sd_model_type} cls={shared.sd_model.__class__.__name__}: unsupported base model')
- return False
- elif shared.sd_model_type == 'f1':
- shared.log.error(f'IP adapter: adapter={adapter_name} type={shared.sd_model_type} cls={shared.sd_model.__class__.__name__}: unsupported base model')
- return False
- else:
- shared.log.error(f'IP adapter: unknown model type: {adapter_name}')
- return False
-
- # load feature extractor used by ip adapter
- if pipe.feature_extractor is None:
- try:
- from transformers import CLIPImageProcessor
- shared.log.debug('IP adapter load: feature extractor')
- pipe.feature_extractor = CLIPImageProcessor()
- except Exception as e:
- shared.log.error(f'IP adapter load: feature extractor {e}')
- return False
+ if not load_image_encoder(pipe, adapter_names):
+ return False
- # load image encoder used by ip adapter
- if pipe.image_encoder is None or clip_loaded != f'{clip_repo}/{clip_subfolder}':
- try:
- from transformers import CLIPVisionModelWithProjection
- shared.log.debug(f'IP adapter load: image encoder="{clip_repo}/{clip_subfolder}"')
- pipe.image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_repo, subfolder=clip_subfolder, torch_dtype=devices.dtype, cache_dir=shared.opts.diffusers_dir, use_safetensors=True)
- clip_loaded = f'{clip_repo}/{clip_subfolder}'
- except Exception as e:
- shared.log.error(f'IP adapter load: image encoder="{clip_repo}/{clip_subfolder}" {e}')
- return False
- sd_models.move_model(pipe.image_encoder, devices.device)
+ if not load_feature_extractor(pipe):
+ return False
# main code
try:
t0 = time.time()
- repos = [adapter['repo'] for adapter in adapters]
- subfolders = [adapter['subfolder'] for adapter in adapters]
- names = [adapter['name'] for adapter in adapters]
- pipe.load_ip_adapter(repos, subfolder=subfolders, weight_name=names)
+ repos = [adapter.get('repo', None) for adapter in adapters if adapter.get('repo', 'none') != 'none']
+ subfolders = [adapter.get('subfolder', None) for adapter in adapters if adapter.get('subfolder', 'none') != 'none']
+ names = [adapter.get('name', None) for adapter in adapters if adapter.get('name', 'none') != 'none']
+ revisions = [adapter.get('revision', None) for adapter in adapters if adapter.get('revision', 'none') != 'none']
+ kwargs = {}
+ if len(repos) == 1:
+ repos = repos[0]
+ if len(subfolders) > 0:
+ kwargs['subfolder'] = subfolders if len(subfolders) > 1 else subfolders[0]
+ if len(names) > 0:
+ kwargs['weight_name'] = names if len(names) > 1 else names[0]
+ if len(revisions) > 0:
+ kwargs['revision'] = revisions[0]
+ pipe.load_ip_adapter(repos, **kwargs)
+ adapters_loaded = names
if hasattr(p, 'ip_adapter_layers'):
pipe.set_ip_adapter_scale(p.ip_adapter_layers)
ip_str = ';'.join(adapter_names) + ':' + json.dumps(p.ip_adapter_layers)
@@ -281,8 +348,8 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt
for i in range(len(adapter_scales)):
if adapter_starts[i] > 0:
adapter_scales[i] = 0.00
- pipe.set_ip_adapter_scale(adapter_scales)
- ip_str = [f'{os.path.splitext(adapter)[0]}:{scale}:{start}:{end}' for adapter, scale, start, end in zip(adapter_names, adapter_scales, adapter_starts, adapter_ends)]
+ pipe.set_ip_adapter_scale(adapter_scales if len(adapter_scales) > 1 else adapter_scales[0])
+ ip_str = [f'{os.path.splitext(adapter)[0]}:{scale}:{start}:{end}:{crop}' for adapter, scale, start, end, crop in zip(adapter_names, adapter_scales, adapter_starts, adapter_ends, adapter_crops)]
p.task_args['ip_adapter_image'] = crop_images(adapter_images, adapter_crops)
if len(adapter_masks) > 0:
p.cross_attention_kwargs = { 'ip_adapter_masks': adapter_masks }
@@ -291,4 +358,5 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt
shared.log.info(f'IP adapter: {ip_str} image={adapter_images} mask={adapter_masks is not None} time={t1-t0:.2f}')
except Exception as e:
shared.log.error(f'IP adapter load: adapters={adapter_names} repo={repos} folders={subfolders} names={names} {e}')
+ errors.display(e, 'IP adapter: type=adapter')
return True
diff --git a/modules/loader.py b/modules/loader.py
index cd51cc8eb..38d942fd4 100644
--- a/modules/loader.py
+++ b/modules/loader.py
@@ -14,7 +14,7 @@
logging.getLogger("DeepSpeed").disabled = True
-os.environ.setdefault('TORCH_LOGS', '-all')
+# os.environ.setdefault('TORCH_LOGS', '-all')
import torch # pylint: disable=C0411
if torch.__version__.startswith('2.5.0'):
errors.log.warning(f'Disabling cuDNN for SDP on torch={torch.__version__}')
diff --git a/modules/lora/extra_networks_lora.py b/modules/lora/extra_networks_lora.py
new file mode 100644
index 000000000..42c4a92f6
--- /dev/null
+++ b/modules/lora/extra_networks_lora.py
@@ -0,0 +1,152 @@
+import re
+import numpy as np
+import modules.lora.networks as networks
+from modules import extra_networks, shared
+
+
+# from https://github.com/cheald/sd-webui-loractl/blob/master/loractl/lib/utils.py
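+# parses stepwise weight schedules such as "0@0,1@0.5" (weight@step pairs separated by "," or ";") and interpolates the weight for the current sampling step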
+def get_stepwise(param, step, steps):
+ def sorted_positions(raw_steps):
+ steps = [[float(s.strip()) for s in re.split("[@~]", x)]
+ for x in re.split("[,;]", str(raw_steps))]
+ if len(steps[0]) == 1: # If we just got a single number, just return it
+ return steps[0][0]
+ steps = [[s[0], s[1] if len(s) == 2 else 1] for s in steps] # Add implicit 1s to any steps which don't have a weight
+ steps.sort(key=lambda k: k[1]) # Sort by index
+ steps = [list(v) for v in zip(*steps)]
+ return steps
+
+ def calculate_weight(m, step, max_steps, step_offset=2):
+ if isinstance(m, list):
+ if m[1][-1] <= 1.0:
+ step = step / (max_steps - step_offset) if max_steps > 0 else 1.0
+ v = np.interp(step, m[1], m[0])
+ return v
+ else:
+ return m
+
+ stepwise = calculate_weight(sorted_positions(param), step, steps)
+ return stepwise
+
+
+def prompt(p):
+ if shared.opts.lora_apply_tags == 0:
+ return
+ all_tags = []
+ for loaded in networks.loaded_networks:
+ page = [en for en in shared.extra_networks if en.name == 'lora'][0]
+ item = page.create_item(loaded.name)
+ tags = (item or {}).get("tags", {})
+ loaded.tags = list(tags)
+ if len(loaded.tags) == 0:
+ loaded.tags.append(loaded.name)
+ if shared.opts.lora_apply_tags > 0:
+ loaded.tags = loaded.tags[:shared.opts.lora_apply_tags]
+ all_tags.extend(loaded.tags)
+ if len(all_tags) > 0:
+ all_tags = list(set(all_tags))
+ all_tags = [t for t in all_tags if t not in p.prompt]
+ shared.log.debug(f"Load network: type=LoRA tags={all_tags} max={shared.opts.lora_apply_tags} apply")
+ all_tags = ', '.join(all_tags)
+ p.extra_generation_params["LoRA tags"] = all_tags
+ if '_tags_' in p.prompt:
+ p.prompt = p.prompt.replace('_tags_', all_tags)
+ else:
+ p.prompt = f"{p.prompt}, {all_tags}"
+ if p.all_prompts is not None:
+ for i in range(len(p.all_prompts)):
+ if '_tags_' in p.all_prompts[i]:
+ p.all_prompts[i] = p.all_prompts[i].replace('_tags_', all_tags)
+ else:
+ p.all_prompts[i] = f"{p.all_prompts[i]}, {all_tags}"
+
+
+def infotext(p):
+ names = [i.name for i in networks.loaded_networks]
+ if len(names) > 0:
+ p.extra_generation_params["LoRA networks"] = ", ".join(names)
+ if shared.opts.lora_add_hashes_to_infotext:
+ network_hashes = []
+ for item in networks.loaded_networks:
+ if not item.network_on_disk.shorthash:
+ continue
+ network_hashes.append(item.network_on_disk.shorthash)
+ if len(network_hashes) > 0:
+ p.extra_generation_params["LoRA hashes"] = ", ".join(network_hashes)
+
+
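+# parses <lora:...> params: positional name/te/unet/dyn plus named te=, unet=, in=, mid=, out= and dyn=; values containing "@" become stepwise schedules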
+def parse(p, params_list, step=0):
+ names = []
+ te_multipliers = []
+ unet_multipliers = []
+ dyn_dims = []
+ for params in params_list:
+ assert params.items
+ names.append(params.positional[0])
+ te_multiplier = params.named.get("te", params.positional[1] if len(params.positional) > 1 else shared.opts.extra_networks_default_multiplier)
+ if isinstance(te_multiplier, str) and "@" in te_multiplier:
+ te_multiplier = get_stepwise(te_multiplier, step, p.steps)
+ else:
+ te_multiplier = float(te_multiplier)
+ unet_multiplier = [params.positional[2] if len(params.positional) > 2 else te_multiplier] * 3
+ unet_multiplier = [params.named.get("unet", unet_multiplier[0])] * 3
+ unet_multiplier[0] = params.named.get("in", unet_multiplier[0])
+ unet_multiplier[1] = params.named.get("mid", unet_multiplier[1])
+ unet_multiplier[2] = params.named.get("out", unet_multiplier[2])
+ for i in range(len(unet_multiplier)):
+ if isinstance(unet_multiplier[i], str) and "@" in unet_multiplier[i]:
+ unet_multiplier[i] = get_stepwise(unet_multiplier[i], step, p.steps)
+ else:
+ unet_multiplier[i] = float(unet_multiplier[i])
+ dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
+ dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
+ te_multipliers.append(te_multiplier)
+ unet_multipliers.append(unet_multiplier)
+ dyn_dims.append(dyn_dim)
+ return names, te_multipliers, unet_multipliers, dyn_dims
+
+
+class ExtraNetworkLora(extra_networks.ExtraNetwork):
+
+ def __init__(self):
+ super().__init__('lora')
+ self.active = False
+ self.model = None
+ self.errors = {}
+
+ def activate(self, p, params_list, step=0, include=[], exclude=[]):
+ self.errors.clear()
+ if self.active:
+ if self.model != shared.opts.sd_model_checkpoint: # reset if model changed
+ self.active = False
+ if len(params_list) > 0 and not self.active: # activate patches once
+ # shared.log.debug(f'Activate network: type=LoRA model="{shared.opts.sd_model_checkpoint}"')
+ self.active = True
+ self.model = shared.opts.sd_model_checkpoint
+ if 'text_encoder' in include:
+ networks.timer.clear(complete=True)
+ names, te_multipliers, unet_multipliers, dyn_dims = parse(p, params_list, step)
+ networks.network_load(names, te_multipliers, unet_multipliers, dyn_dims) # load
+ networks.network_activate(include, exclude)
+ if len(networks.loaded_networks) > 0 and len(networks.applied_layers) > 0 and step == 0:
+ infotext(p)
+ prompt(p)
+ shared.log.info(f'Load network: type=LoRA apply={[n.name for n in networks.loaded_networks]} mode={"fuse" if shared.opts.lora_fuse_diffusers else "backup"} te={te_multipliers} unet={unet_multipliers} time={networks.timer.summary}')
+
+ def deactivate(self, p):
+ if shared.native and len(networks.diffuser_loaded) > 0:
+ if hasattr(shared.sd_model, "unload_lora_weights") and hasattr(shared.sd_model, "text_encoder"):
+ if not (shared.compiled_model_state is not None and shared.compiled_model_state.is_compiled is True):
+ try:
+ if shared.opts.lora_fuse_diffusers:
+ shared.sd_model.unfuse_lora()
+ shared.sd_model.unload_lora_weights() # fails for non-CLIP models
+ except Exception:
+ pass
+ networks.network_deactivate()
+ if self.active and networks.debug:
+ shared.log.debug(f"Network end: type=LoRA time={networks.timer.summary}")
+ if self.errors:
+ for k, v in self.errors.items():
+ shared.log.error(f'LoRA: name="{k}" errors={v}')
+ self.errors.clear()
diff --git a/modules/lora/lora_convert.py b/modules/lora/lora_convert.py
new file mode 100644
index 000000000..032ffa5a3
--- /dev/null
+++ b/modules/lora/lora_convert.py
@@ -0,0 +1,509 @@
+import os
+import re
+import bisect
+from typing import Dict
+import torch
+from modules import shared
+
+
+debug = os.environ.get('SD_LORA_DEBUG', None) is not None
+suffix_conversion = {
+ "attentions": {},
+ "resnets": {
+ "conv1": "in_layers_2",
+ "conv2": "out_layers_3",
+ "norm1": "in_layers_0",
+ "norm2": "out_layers_0",
+ "time_emb_proj": "emb_layers_1",
+ "conv_shortcut": "skip_connection",
+ }
+}
+re_digits = re.compile(r"\d+")
+re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
+re_compiled = {}
+
+
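+# builds the kohya (stable-diffusion) to HF Diffusers key-prefix map for the SDXL UNet, consumed by KeyConvert below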
+def make_unet_conversion_map() -> Dict[str, str]:
+ unet_conversion_map_layer = []
+
+ for i in range(3): # num_blocks is 3 in sdxl
+ # loop over downblocks/upblocks
+ for j in range(2):
+ # loop over resnets/attentions for downblocks
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+ sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+ if i < 3:
+ # no attention layers in down_blocks.3
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+ sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+ for j in range(3):
+ # loop over resnets/attentions for upblocks
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+ sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+ # if i > 0: commentout for sdxl
+ # no attention layers in up_blocks.0
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+ sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+ if i < 3:
+ # no downsample in down_blocks.3
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+ sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+ # no upsample in up_blocks.3
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{2}." # change for sdxl
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ hf_mid_atn_prefix = "mid_block.attentions.0."
+ sd_mid_atn_prefix = "middle_block.1."
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+ for j in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
+ sd_mid_res_prefix = f"middle_block.{2 * j}."
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+ unet_conversion_map_resnet = [
+ # (stable-diffusion, HF Diffusers)
+ ("in_layers.0.", "norm1."),
+ ("in_layers.2.", "conv1."),
+ ("out_layers.0.", "norm2."),
+ ("out_layers.3.", "conv2."),
+ ("emb_layers.1.", "time_emb_proj."),
+ ("skip_connection.", "conv_shortcut."),
+ ]
+
+ unet_conversion_map = []
+ for sd, hf in unet_conversion_map_layer:
+ if "resnets" in hf:
+ for sd_res, hf_res in unet_conversion_map_resnet:
+ unet_conversion_map.append((sd + sd_res, hf + hf_res))
+ else:
+ unet_conversion_map.append((sd, hf))
+
+ for j in range(2):
+ hf_time_embed_prefix = f"time_embedding.linear_{j + 1}."
+ sd_time_embed_prefix = f"time_embed.{j * 2}."
+ unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
+
+ for j in range(2):
+ hf_label_embed_prefix = f"add_embedding.linear_{j + 1}."
+ sd_label_embed_prefix = f"label_emb.0.{j * 2}."
+ unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
+
+ unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
+ unet_conversion_map.append(("out.0.", "conv_norm_out."))
+ unet_conversion_map.append(("out.2.", "conv_out."))
+
+ sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
+ return sd_hf_conversion_map
+
+
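+# resolves a raw LoRA state-dict key to the matching module in shared.sd_model.network_layer_mapping, converting kohya/OFT SDXL prefixes along the way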
+class KeyConvert:
+ def __init__(self):
+ self.is_sdxl = True if shared.sd_model_type == "sdxl" else False
+ self.UNET_CONVERSION_MAP = make_unet_conversion_map() if self.is_sdxl else None
+ self.LORA_PREFIX_UNET = "lora_unet_"
+ self.LORA_PREFIX_TEXT_ENCODER = "lora_te_"
+ self.OFT_PREFIX_UNET = "oft_unet_"
+ # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
+ self.LORA_PREFIX_TEXT_ENCODER1 = "lora_te1_"
+ self.LORA_PREFIX_TEXT_ENCODER2 = "lora_te2_"
+
+ def __call__(self, key):
+ if self.is_sdxl:
+ if "diffusion_model" in key: # Fix NTC Slider naming error
+ key = key.replace("diffusion_model", "lora_unet")
+ map_keys = list(self.UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules
+ map_keys.sort()
+ search_key = key.replace(self.LORA_PREFIX_UNET, "").replace(self.OFT_PREFIX_UNET, "").replace(self.LORA_PREFIX_TEXT_ENCODER1, "").replace(self.LORA_PREFIX_TEXT_ENCODER2, "")
+ position = bisect.bisect_right(map_keys, search_key)
+ map_key = map_keys[position - 1]
+ if search_key.startswith(map_key):
+ key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key]).replace("oft", "lora") # pylint: disable=unsubscriptable-object
+ if "lycoris" in key and "transformer" in key:
+ key = key.replace("lycoris", "lora_transformer")
+ sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+ if sd_module is None:
+ sd_module = shared.sd_model.network_layer_mapping.get(key.replace("guidance", "timestep"), None) # FLUX1 fix
+ if debug and sd_module is None:
+ raise RuntimeError(f"LoRA key not found in network_layer_mapping: key={key} mapping={shared.sd_model.network_layer_mapping.keys()}")
+ return key, sd_module
+
+
+# Taken from https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_conversion_utils.py
+# Modified from 'lora_A' and 'lora_B' to 'lora_down' and 'lora_up'
+# Added early exit
+# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
+# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
+# All credits go to `kohya-ss`.
+def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
+ if sds_key + ".lora_down.weight" not in sds_sd:
+ return
+ down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+
+ # scale weight by alpha and dim
+ rank = down_weight.shape[0]
+ alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
+ scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+
+ # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
+ scale_down = scale
+ scale_up = 1.0
+ while scale_down * 2 < scale_up:
+ scale_down *= 2
+ scale_up /= 2
+
+ ait_sd[ait_key + ".lora_down.weight"] = down_weight * scale_down
+ ait_sd[ait_key + ".lora_up.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
+
+def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
+ if sds_key + ".lora_down.weight" not in sds_sd:
+ return
+ down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+ up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
+ sd_lora_rank = down_weight.shape[0]
+
+ # scale weight by alpha and dim
+ alpha = sds_sd.pop(sds_key + ".alpha")
+ scale = alpha / sd_lora_rank
+
+ # calculate scale_down and scale_up
+ scale_down = scale
+ scale_up = 1.0
+ while scale_down * 2 < scale_up:
+ scale_down *= 2
+ scale_up /= 2
+
+ down_weight = down_weight * scale_down
+ up_weight = up_weight * scale_up
+
+ # calculate dims if not provided
+ num_splits = len(ait_keys)
+ if dims is None:
+ dims = [up_weight.shape[0] // num_splits] * num_splits
+ else:
+ assert sum(dims) == up_weight.shape[0]
+
+ # check upweight is sparse or not
+ is_sparse = False
+ if sd_lora_rank % num_splits == 0:
+ ait_rank = sd_lora_rank // num_splits
+ is_sparse = True
+ i = 0
+ for j in range(len(dims)):
+ for k in range(len(dims)):
+ if j == k:
+ continue
+ is_sparse = is_sparse and torch.all(
+ up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
+ )
+ i += dims[j]
+ # if is_sparse:
+ # print(f"weight is sparse: {sds_key}")
+
+ # make ai-toolkit weight
+ ait_down_keys = [k + ".lora_down.weight" for k in ait_keys]
+ ait_up_keys = [k + ".lora_up.weight" for k in ait_keys]
+ if not is_sparse:
+ # down_weight is copied to each split
+ ait_sd.update({k: down_weight for k in ait_down_keys})
+
+ # up_weight is split to each split
+ ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 # pylint: disable=unnecessary-comprehension
+ else:
+ # down_weight is chunked to each split
+ ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416 # pylint: disable=unnecessary-comprehension
+
+ # up_weight is sparse: only non-zero values are copied to each split
+ i = 0
+ for j in range(len(dims)):
+ ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
+ i += dims[j]
+
+def _convert_text_encoder_lora_key(key, lora_name):
+ """
+ Converts a text encoder LoRA key to a Diffusers compatible key.
+ """
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
+ key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
+ else:
+ key_to_replace = "lora_te2_"
+
+ diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+ diffusers_name = diffusers_name.replace("text.projection", "text_projection")
+
+ if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
+ pass
+ elif "mlp" in diffusers_name:
+ # Be aware that this is the new diffusers convention and the rest of the code might
+ # not utilize it yet.
+ diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+ return diffusers_name
+
+def _convert_kohya_flux_lora_to_diffusers(state_dict):
+ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
+ ait_sd = {}
+ for i in range(19):
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_double_blocks_{i}_img_attn_proj",
+ f"transformer.transformer_blocks.{i}.attn.to_out.0",
+ )
+ _convert_to_ai_toolkit_cat(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_double_blocks_{i}_img_attn_qkv",
+ [
+ f"transformer.transformer_blocks.{i}.attn.to_q",
+ f"transformer.transformer_blocks.{i}.attn.to_k",
+ f"transformer.transformer_blocks.{i}.attn.to_v",
+ ],
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_double_blocks_{i}_img_mlp_0",
+ f"transformer.transformer_blocks.{i}.ff.net.0.proj",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_double_blocks_{i}_img_mlp_2",
+ f"transformer.transformer_blocks.{i}.ff.net.2",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_double_blocks_{i}_img_mod_lin",
+ f"transformer.transformer_blocks.{i}.norm1.linear",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_double_blocks_{i}_txt_attn_proj",
+ f"transformer.transformer_blocks.{i}.attn.to_add_out",
+ )
+ _convert_to_ai_toolkit_cat(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_double_blocks_{i}_txt_attn_qkv",
+ [
+ f"transformer.transformer_blocks.{i}.attn.add_q_proj",
+ f"transformer.transformer_blocks.{i}.attn.add_k_proj",
+ f"transformer.transformer_blocks.{i}.attn.add_v_proj",
+ ],
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_double_blocks_{i}_txt_mlp_0",
+ f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_double_blocks_{i}_txt_mlp_2",
+ f"transformer.transformer_blocks.{i}.ff_context.net.2",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_double_blocks_{i}_txt_mod_lin",
+ f"transformer.transformer_blocks.{i}.norm1_context.linear",
+ )
+
+ for i in range(38):
+ _convert_to_ai_toolkit_cat(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_single_blocks_{i}_linear1",
+ [
+ f"transformer.single_transformer_blocks.{i}.attn.to_q",
+ f"transformer.single_transformer_blocks.{i}.attn.to_k",
+ f"transformer.single_transformer_blocks.{i}.attn.to_v",
+ f"transformer.single_transformer_blocks.{i}.proj_mlp",
+ ],
+ dims=[3072, 3072, 3072, 12288],
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_single_blocks_{i}_linear2",
+ f"transformer.single_transformer_blocks.{i}.proj_out",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_single_blocks_{i}_modulation_lin",
+ f"transformer.single_transformer_blocks.{i}.norm.linear",
+ )
+
+ if len(sds_sd) > 0:
+ return None
+
+ return ait_sd
+
+ return _convert_sd_scripts_to_ai_toolkit(state_dict)
+
+def _convert_kohya_sd3_lora_to_diffusers(state_dict):
+ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
+ ait_sd = {}
+ for i in range(38):
+ _convert_to_ai_toolkit_cat(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_joint_blocks_{i}_context_block_attn_qkv",
+ [
+ f"transformer.transformer_blocks.{i}.attn.to_q",
+ f"transformer.transformer_blocks.{i}.attn.to_k",
+ f"transformer.transformer_blocks.{i}.attn.to_v",
+ ],
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_joint_blocks_{i}_context_block_mlp_fc1",
+ f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_joint_blocks_{i}_context_block_mlp_fc2",
+ f"transformer.transformer_blocks.{i}.ff_context.net.2",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_joint_blocks_{i}_x_block_mlp_fc1",
+ f"transformer.transformer_blocks.{i}.ff.net.0.proj",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_joint_blocks_{i}_x_block_mlp_fc2",
+ f"transformer.transformer_blocks.{i}.ff.net.2",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_joint_blocks_{i}_context_block_adaLN_modulation_1",
+ f"transformer.transformer_blocks.{i}.norm1_context.linear",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_joint_blocks_{i}_x_block_adaLN_modulation_1",
+ f"transformer.transformer_blocks.{i}.norm1.linear",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_joint_blocks_{i}_context_block_attn_proj",
+ f"transformer.transformer_blocks.{i}.attn.to_add_out",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_joint_blocks_{i}_x_block_attn_proj",
+ f"transformer.transformer_blocks.{i}.attn.to_out_0",
+ )
+
+ _convert_to_ai_toolkit_cat(
+ sds_sd,
+ ait_sd,
+ f"lora_unet_joint_blocks_{i}_x_block_attn_qkv",
+ [
+ f"transformer.transformer_blocks.{i}.attn.add_q_proj",
+ f"transformer.transformer_blocks.{i}.attn.add_k_proj",
+ f"transformer.transformer_blocks.{i}.attn.add_v_proj",
+ ],
+ )
+ remaining_keys = list(sds_sd.keys())
+ te_state_dict = {}
+ if remaining_keys:
+ if not all(k.startswith("lora_te1") for k in remaining_keys):
+ raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
+ for key in remaining_keys:
+ if not key.endswith("lora_down.weight"):
+ continue
+
+ lora_name = key.split(".")[0]
+ lora_name_up = f"{lora_name}.lora_up.weight"
+ lora_name_alpha = f"{lora_name}.alpha"
+ diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
+
+ sd_lora_rank = 1
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
+ down_weight = sds_sd.pop(key)
+ sd_lora_rank = down_weight.shape[0]
+ te_state_dict[diffusers_name] = down_weight
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)
+
+ if lora_name_alpha in sds_sd:
+ alpha = sds_sd.pop(lora_name_alpha).item()
+ scale = alpha / sd_lora_rank
+
+ scale_down = scale
+ scale_up = 1.0
+ while scale_down * 2 < scale_up:
+ scale_down *= 2
+ scale_up /= 2
+
+ te_state_dict[diffusers_name] *= scale_down
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up
+
+ if len(sds_sd) > 0:
+ print(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
+
+ if te_state_dict:
+ te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}
+
+ new_state_dict = {**ait_sd, **te_state_dict}
+ return new_state_dict
+
+ return _convert_sd_scripts_to_ai_toolkit(state_dict)
+
+
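+# walks the text encoder(s), unet and transformer, assigns each module a "lora_*" network name and stores the resulting network_layer_mapping on the model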
+def assign_network_names_to_compvis_modules(sd_model):
+ if sd_model is None:
+ return
+ sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatibility
+ network_layer_mapping = {}
+ if hasattr(sd_model, 'text_encoder') and sd_model.text_encoder is not None:
+ for name, module in sd_model.text_encoder.named_modules():
+ prefix = "lora_te1_" if hasattr(sd_model, 'text_encoder_2') else "lora_te_"
+ network_name = prefix + name.replace(".", "_")
+ network_layer_mapping[network_name] = module
+ module.network_layer_name = network_name
+ if hasattr(sd_model, 'text_encoder_2'):
+ for name, module in sd_model.text_encoder_2.named_modules():
+ network_name = "lora_te2_" + name.replace(".", "_")
+ network_layer_mapping[network_name] = module
+ module.network_layer_name = network_name
+ if hasattr(sd_model, 'unet'):
+ for name, module in sd_model.unet.named_modules():
+ network_name = "lora_unet_" + name.replace(".", "_")
+ network_layer_mapping[network_name] = module
+ module.network_layer_name = network_name
+ if hasattr(sd_model, 'transformer'):
+ for name, module in sd_model.transformer.named_modules():
+ network_name = "lora_transformer_" + name.replace(".", "_")
+ network_layer_mapping[network_name] = module
+ if "norm" in network_name and "linear" not in network_name and shared.sd_model_type != "sd3":
+ continue
+ module.network_layer_name = network_name
+ shared.sd_model.network_layer_mapping = network_layer_mapping
diff --git a/modules/lora/lora_extract.py b/modules/lora/lora_extract.py
new file mode 100644
index 000000000..58cd065bb
--- /dev/null
+++ b/modules/lora/lora_extract.py
@@ -0,0 +1,271 @@
+import os
+import time
+import json
+import datetime
+import torch
+from safetensors.torch import save_file
+import gradio as gr
+from rich import progress as p
+from modules import shared, devices
+from modules.ui_common import create_refresh_button
+from modules.call_queue import wrap_gradio_gpu_call
+
+
+class SVDHandler:
+ def __init__(self, maxrank=0, rank_ratio=1):
+ self.network_name: str = None
+ self.U: torch.Tensor = None
+ self.S: torch.Tensor = None
+ self.Vh: torch.Tensor = None
+ self.maxrank: int = maxrank
+ self.rank_ratio: float = rank_ratio
+ self.rank: int = 0
+ self.out_size: int = None
+ self.in_size: int = None
+ self.kernel_size: tuple[int, int] = None
+ self.conv2d: bool = False
+
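+ # decomposes the delta between the current (LoRA-applied) weight and its backup via low-rank SVD; factors are kept on CPU in bfloat16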
+ def decompose(self, weight, backupweight):
+ self.conv2d = len(weight.size()) == 4
+ self.kernel_size = None if not self.conv2d else weight.size()[2:4]
+ self.out_size, self.in_size = weight.size()[0:2]
+ diffweight = weight.clone().to(devices.device)
+ diffweight -= backupweight.to(devices.device)
+ if self.conv2d:
+ if self.conv2d and self.kernel_size != (1, 1):
+ diffweight = diffweight.flatten(start_dim=1)
+ else:
+ diffweight = diffweight.squeeze()
+ self.U, self.S, self.Vh = torch.svd_lowrank(diffweight.to(device=devices.device, dtype=torch.float), self.maxrank, 2)
+ # del diffweight
+ self.U = self.U.to(device=devices.cpu, dtype=torch.bfloat16)
+ self.S = self.S.to(device=devices.cpu, dtype=torch.bfloat16)
+ self.Vh = self.Vh.t().to(device=devices.cpu, dtype=torch.bfloat16) # svd_lowrank outputs a transposed matrix
+
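+ # picks the smallest rank whose cumulative squared singular values reach rank_ratio^2 of the total energy, optionally capped by maxrank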
+ def findrank(self):
+ if self.rank_ratio < 1:
+ S_squared = self.S.pow(2)
+ S_fro_sq = float(torch.sum(S_squared))
+ sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
+ index = int(torch.searchsorted(sum_S_squared, self.rank_ratio ** 2)) + 1
+ index = max(1, min(index, len(self.S) - 1))
+ self.rank = index
+ if self.maxrank > 0:
+ self.rank = min(self.rank, self.maxrank)
+ else:
+ self.rank = min(self.in_size, self.out_size, self.maxrank)
+
+ def makeweights(self):
+ self.findrank()
+ up = self.U[:, :self.rank] @ torch.diag(self.S[:self.rank])
+ down = self.Vh[:self.rank, :]
+ if self.conv2d and self.kernel_size is not None:
+ up = up.reshape(self.out_size, self.rank, 1, 1)
+ down = down.reshape(self.rank, self.in_size, self.kernel_size[0], self.kernel_size[1]) # pylint: disable=unsubscriptable-object
+ return_dict = {f'{self.network_name}.lora_up.weight': up.contiguous(),
+ f'{self.network_name}.lora_down.weight': down.contiguous(),
+ f'{self.network_name}.alpha': torch.tensor(down.shape[0]),
+ }
+ return return_dict
+
+
+def loaded_lora():
+ if not shared.sd_loaded:
+ return ""
+ loaded = set()
+ if hasattr(shared.sd_model, 'unet'):
+ for _name, module in shared.sd_model.unet.named_modules():
+ current = getattr(module, "network_current_names", None)
+ if current is not None:
+ current = [item[0] for item in current]
+ loaded.update(current)
+ return list(loaded)
+
+
+def loaded_lora_str():
+ return ", ".join(loaded_lora())
+
+
+def make_meta(fn, maxrank, rank_ratio):
+ meta = {
+ "model_spec.sai_model_spec": "1.0.0",
+ "model_spec.title": os.path.splitext(os.path.basename(fn))[0],
+ "model_spec.author": "SD.Next",
+ "model_spec.implementation": "https://github.com/vladmandic/automatic",
+ "model_spec.date": datetime.datetime.now().astimezone().replace(microsecond=0).isoformat(),
+ "model_spec.base_model": shared.opts.sd_model_checkpoint,
+ "model_spec.dtype": str(devices.dtype),
+ "model_spec.base_lora": json.dumps(loaded_lora()),
+ "model_spec.config": f"maxrank={maxrank} rank_ratio={rank_ratio}",
+ }
+ if shared.sd_model_type == "sdxl":
+ meta["model_spec.architecture"] = "stable-diffusion-xl-v1-base/lora" # sai standard
+ meta["ss_base_model_version"] = "sdxl_base_v1-0" # kohya standard
+ elif shared.sd_model_type == "sd":
+ meta["model_spec.architecture"] = "stable-diffusion-v1/lora"
+ meta["ss_base_model_version"] = "sd_v1"
+ elif shared.sd_model_type == "f1":
+ meta["model_spec.architecture"] = "flux-1-dev/lora"
+ meta["ss_base_model_version"] = "flux1"
+ elif shared.sd_model_type == "sc":
+ meta["model_spec.architecture"] = "stable-cascade-v1-prior/lora"
+ return meta
+
+
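+# extracts a LoRA from the currently loaded model by SVD-decomposing per-module weight deltas (current vs backup) and saving them as a .safetensors file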
+def make_lora(fn, maxrank, auto_rank, rank_ratio, modules, overwrite):
+ if not shared.sd_loaded or not shared.native:
+ msg = "LoRA extract: model not loaded"
+ shared.log.warning(msg)
+ yield msg
+ return
+ if loaded_lora() == "":
+ msg = "LoRA extract: no LoRA detected"
+ shared.log.warning(msg)
+ yield msg
+ return
+ if not fn:
+ msg = "LoRA extract: target filename required"
+ shared.log.warning(msg)
+ yield msg
+ return
+ t0 = time.time()
+ maxrank = int(maxrank)
+ rank_ratio = 1 if not auto_rank else rank_ratio
+ shared.log.debug(f'LoRA extract: modules={modules} maxrank={maxrank} auto={auto_rank} ratio={rank_ratio} fn="{fn}"')
+ shared.state.begin('LoRA extract')
+
+ with p.Progress(p.TextColumn('[cyan]LoRA extract'), p.BarColumn(), p.TaskProgressColumn(), p.TimeRemainingColumn(), p.TimeElapsedColumn(), p.TextColumn('[cyan]{task.description}'), console=shared.console) as progress:
+
+ if 'te' in modules and getattr(shared.sd_model, 'text_encoder', None) is not None:
+ submodules = shared.sd_model.text_encoder.named_modules() # separate name so the modules argument is not shadowed
+ task = progress.add_task(description="te1 decompose", total=len(list(submodules)))
+ for name, module in shared.sd_model.text_encoder.named_modules():
+ progress.update(task, advance=1)
+ weights_backup = getattr(module, "network_weights_backup", None)
+ if weights_backup is None or getattr(module, "network_current_names", None) is None:
+ continue
+ prefix = "lora_te1_" if hasattr(shared.sd_model, 'text_encoder_2') else "lora_te_"
+ module.svdhandler = SVDHandler(maxrank, rank_ratio)
+ module.svdhandler.network_name = prefix + name.replace(".", "_")
+ with devices.inference_context():
+ module.svdhandler.decompose(module.weight, weights_backup)
+ progress.remove_task(task)
+ t1 = time.time()
+
+ if 'te' in modules and getattr(shared.sd_model, 'text_encoder_2', None) is not None:
+ submodules = shared.sd_model.text_encoder_2.named_modules()
+ task = progress.add_task(description="te2 decompose", total=len(list(submodules)))
+ for name, module in shared.sd_model.text_encoder_2.named_modules():
+ progress.update(task, advance=1)
+ weights_backup = getattr(module, "network_weights_backup", None)
+ if weights_backup is None or getattr(module, "network_current_names", None) is None:
+ continue
+ module.svdhandler = SVDHandler(maxrank, rank_ratio)
+ module.svdhandler.network_name = "lora_te2_" + name.replace(".", "_")
+ with devices.inference_context():
+ module.svdhandler.decompose(module.weight, weights_backup)
+ progress.remove_task(task)
+ t2 = time.time()
+
+ if 'unet' in modules and getattr(shared.sd_model, 'unet', None) is not None:
+ submodules = shared.sd_model.unet.named_modules()
+ task = progress.add_task(description="unet decompose", total=len(list(submodules)))
+ for name, module in shared.sd_model.unet.named_modules():
+ progress.update(task, advance=1)
+ weights_backup = getattr(module, "network_weights_backup", None)
+ if weights_backup is None or getattr(module, "network_current_names", None) is None:
+ continue
+ module.svdhandler = SVDHandler(maxrank, rank_ratio)
+ module.svdhandler.network_name = "lora_unet_" + name.replace(".", "_")
+ with devices.inference_context():
+ module.svdhandler.decompose(module.weight, weights_backup)
+ progress.remove_task(task)
+ t3 = time.time()
+
+ # TODO: lora make support quantized flux
+ # if 'te' in modules and getattr(shared.sd_model, 'transformer', None) is not None:
+ # for name, module in shared.sd_model.transformer.named_modules():
+ # if "norm" in name and "linear" not in name:
+ # continue
+ # weights_backup = getattr(module, "network_weights_backup", None)
+ # if weights_backup is None:
+ # continue
+ # module.svdhandler = SVDHandler()
+ # module.svdhandler.network_name = "lora_transformer_" + name.replace(".", "_")
+ # module.svdhandler.decompose(module.weight, weights_backup)
+ # module.svdhandler.findrank(rank, rank_ratio)
+
+ lora_state_dict = {}
+ for sub in ['text_encoder', 'text_encoder_2', 'unet', 'transformer']:
+ submodel = getattr(shared.sd_model, sub, None)
+ if submodel is not None:
+ submodules = submodel.named_modules()
+ task = progress.add_task(description=f"{sub} extract", total=len(list(submodules)))
+ for _name, module in submodel.named_modules():
+ progress.update(task, advance=1)
+ if not hasattr(module, "svdhandler"):
+ continue
+ lora_state_dict.update(module.svdhandler.makeweights())
+ del module.svdhandler
+ progress.remove_task(task)
+ t4 = time.time()
+
+ if not os.path.isabs(fn):
+ fn = os.path.join(shared.cmd_opts.lora_dir, fn)
+ if not fn.endswith('.safetensors'):
+ fn += '.safetensors'
+ if os.path.exists(fn):
+ if overwrite:
+ os.remove(fn)
+ else:
+ msg = f'LoRA extract: fn="{fn}" file exists'
+ shared.log.warning(msg)
+ yield msg
+ return
+
+ shared.state.end()
+ meta = make_meta(fn, maxrank, rank_ratio)
+ shared.log.debug(f'LoRA metadata: {meta}')
+ try:
+ save_file(tensors=lora_state_dict, metadata=meta, filename=fn)
+ except Exception as e:
+ msg = f'LoRA extract error: fn="{fn}" {e}'
+ shared.log.error(msg)
+ yield msg
+ return
+ t5 = time.time()
+ shared.log.debug(f'LoRA extract: time={t5-t0:.2f} te1={t1-t0:.2f} te2={t2-t1:.2f} unet={t3-t2:.2f} save={t5-t4:.2f}')
+ keys = list(lora_state_dict.keys())
+ msg = f'LoRA extract: fn="{fn}" keys={len(keys)}'
+ shared.log.info(msg)
+ yield msg
+
+
+def create_ui():
+ def gr_show(visible=True):
+ return {"visible": visible, "__type__": "update"}
+
+ with gr.Tab(label="Extract LoRA"):
+ with gr.Row():
+ loaded = gr.Textbox(placeholder="Press refresh to query loaded LoRA", label="Loaded LoRA", interactive=False)
+ create_refresh_button(loaded, lambda: None, lambda: {'value': loaded_lora_str()}, "testid")
+ with gr.Group():
+ with gr.Row():
+ modules = gr.CheckboxGroup(label="Modules to extract", value=['unet'], choices=['te', 'unet'])
+ with gr.Row():
+ auto_rank = gr.Checkbox(value=False, label="Automatically determine rank")
+ rank_ratio = gr.Slider(label="Autorank ratio", value=1, minimum=0, maximum=1, step=0.05, visible=False)
+ rank = gr.Slider(label="Maximum rank", value=32, minimum=1, maximum=256)
+ with gr.Row():
+ filename = gr.Textbox(label="LoRA target filename")
+ overwrite = gr.Checkbox(value=False, label="Overwrite existing file")
+ with gr.Row():
+ extract = gr.Button(value="Extract LoRA", variant='primary')
+ status = gr.HTML(value="", show_label=False)
+
+ auto_rank.change(fn=lambda x: gr_show(x), inputs=[auto_rank], outputs=[rank_ratio])
+ extract.click(
+ fn=wrap_gradio_gpu_call(make_lora, extra_outputs=[]),
+ inputs=[filename, rank, auto_rank, rank_ratio, modules, overwrite],
+ outputs=[status]
+ )
diff --git a/modules/lora/lora_timers.py b/modules/lora/lora_timers.py
new file mode 100644
index 000000000..30c35a728
--- /dev/null
+++ b/modules/lora/lora_timers.py
@@ -0,0 +1,38 @@
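+# accumulates per-phase timings for LoRA operations; summary reports only phases that took longer than 0.1s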
+class Timer():
+ list: float = 0
+ load: float = 0
+ backup: float = 0
+ calc: float = 0
+ apply: float = 0
+ move: float = 0
+ restore: float = 0
+ activate: float = 0
+ deactivate: float = 0
+
+ @property
+ def total(self):
+ return round(self.activate + self.deactivate, 2)
+
+ @property
+ def summary(self):
+ t = {}
+ for k, v in self.__dict__.items():
+ if v > 0.1:
+ t[k] = round(v, 2)
+ return t
+
+ def clear(self, complete: bool = False):
+ self.backup = 0
+ self.calc = 0
+ self.apply = 0
+ self.move = 0
+ self.restore = 0
+ if complete:
+ self.activate = 0
+ self.deactivate = 0
+
+ def add(self, name, t):
+ self.__dict__[name] += t
+
+ def __str__(self):
+ return f'{self.__class__.__name__}({self.summary})'
diff --git a/modules/lora/lyco_helpers.py b/modules/lora/lyco_helpers.py
new file mode 100644
index 000000000..ac4f2419f
--- /dev/null
+++ b/modules/lora/lyco_helpers.py
@@ -0,0 +1,66 @@
+import torch
+
+
+def make_weight_cp(t, wa, wb):
+ temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
+ return torch.einsum('i j k l, i r -> r j k l', temp, wa)
+
+
+def rebuild_conventional(up, down, shape, dyn_dim=None):
+ up = up.reshape(up.size(0), -1)
+ down = down.reshape(down.size(0), -1)
+ if dyn_dim is not None:
+ up = up[:, :dyn_dim]
+ down = down[:dyn_dim, :]
+ return (up @ down).reshape(shape).to(up.dtype)
+
+
+def rebuild_cp_decomposition(up, down, mid):
+ up = up.reshape(up.size(0), -1)
+ down = down.reshape(down.size(0), -1)
+ return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down).to(up.dtype)
+
+
+# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
+def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
+ """
+ Return a tuple of two values decomposing the input dimension by the number closest to factor;
+ the second value is greater than or equal to the first.
+
+ In LoRA with Kronecker Product, the first value is used for the weight scale and
+ the second value is used for the weight itself.
+
+ Because of the non-commutative property, A⊗B ≠ B⊗A, the meaning of the two matrices differs slightly.
+
+ examples
+ factor
+ -1 2 4 8 16 ...
+ 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
+ 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
+ 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
+ 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
+ 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
+ 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
+ """
+
+ if factor > 0 and (dimension % factor) == 0:
+ m = factor
+ n = dimension // factor
+ if m > n:
+ n, m = m, n
+ return m, n
+ if factor < 0:
+ factor = dimension
+ m, n = 1, dimension
+ length = m + n
+ while m < n:
+ new_m = m + 1
+ while dimension % new_m != 0:
+ new_m += 1
+ new_n = dimension // new_m
+ if new_m + new_n > length or new_m > factor:
+ break
+ m, n = new_m, new_n
+ if m > n:
+ n, m = m, n
+ return m, n
diff --git a/modules/lora/network.py b/modules/lora/network.py
new file mode 100644
index 000000000..97feb76f1
--- /dev/null
+++ b/modules/lora/network.py
@@ -0,0 +1,188 @@
+import os
+import enum
+from typing import Union
+from collections import namedtuple
+from modules import sd_models, hashes, shared
+
+
+NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
+metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
+
+
+class SdVersion(enum.Enum):
+ Unknown = 1
+ SD1 = 2
+ SD2 = 3
+ SD3 = 4
+ SDXL = 5
+ SC = 6
+ F1 = 7
+
+
+class NetworkOnDisk:
+ def __init__(self, name, filename):
+ self.shorthash = None
+ self.hash = None
+ self.name = name
+ self.filename = filename
+ if filename.startswith(shared.cmd_opts.lora_dir):
+ self.fullname = os.path.splitext(filename[len(shared.cmd_opts.lora_dir):].strip("/"))[0]
+ else:
+ self.fullname = name
+ self.metadata = {}
+ self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
+ if self.is_safetensors:
+ self.metadata = sd_models.read_metadata_from_safetensors(filename)
+ if self.metadata:
+ m = {}
+ for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
+ m[k] = v
+ self.metadata = m
+ self.alias = self.metadata.get('ss_output_name', self.name)
+ sha256 = hashes.sha256_from_cache(self.filename, "lora/" + self.name) or hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=True) or self.metadata.get('sshs_model_hash')
+ self.set_hash(sha256)
+ self.sd_version = self.detect_version()
+
+ def detect_version(self):
+ base = str(self.metadata.get('ss_base_model_version', "")).lower()
+ arch = str(self.metadata.get('modelspec.architecture', "")).lower()
+ if base.startswith("sd_v1"):
+ return 'sd1'
+ if base.startswith("sdxl"):
+ return 'xl'
+ if base.startswith("stable_cascade"):
+ return 'sc'
+ if base.startswith("sd3"):
+ return 'sd3'
+ if base.startswith("flux"):
+ return 'f1'
+
+ if arch.startswith("stable-diffusion-v1"):
+ return 'sd1'
+ if arch.startswith("stable-diffusion-xl"):
+ return 'xl'
+ if arch.startswith("stable-cascade"):
+ return 'sc'
+ if arch.startswith("flux"):
+ return 'f1'
+
+ if "v1-5" in str(self.metadata.get('ss_sd_model_name', "")):
+ return 'sd1'
+ if str(self.metadata.get('ss_v2', "")) == "True":
+ return 'sd2'
+ if 'flux' in self.name.lower():
+ return 'f1'
+ if 'xl' in self.name.lower():
+ return 'xl'
+
+ return ''
+
+ def set_hash(self, v):
+ self.hash = v or ''
+ self.shorthash = self.hash[0:8]
+
+ def read_hash(self):
+ if not self.hash:
+ self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
+
+ def get_alias(self):
+ import modules.lora.networks as networks
+ return self.name if shared.opts.lora_preferred_name == "filename" or self.alias.lower() in networks.forbidden_network_aliases else self.alias
+
+
+class Network: # LoraModule
+ def __init__(self, name, network_on_disk: NetworkOnDisk):
+ self.name = name
+ self.network_on_disk = network_on_disk
+ self.te_multiplier = 1.0
+ self.unet_multiplier = [1.0] * 3
+ self.dyn_dim = None
+ self.modules = {}
+ self.bundle_embeddings = {}
+ self.mtime = None
+ self.mentioned_name = None
+ """the text that was used to add the network to prompt - can be either name or an alias"""
+ self.tags = None
+
+
+class ModuleType:
+ def create_module(self, net: Network, weights: NetworkWeights) -> Union[Network, None]: # pylint: disable=W0613
+ return None
+
+
+class NetworkModule:
+ def __init__(self, net: Network, weights: NetworkWeights):
+ self.network = net
+ self.network_key = weights.network_key
+ self.sd_key = weights.sd_key
+ self.sd_module = weights.sd_module
+ if hasattr(self.sd_module, 'weight'):
+ self.shape = self.sd_module.weight.shape
+ self.dim = None
+ self.bias = weights.w.get("bias")
+ self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
+ self.scale = weights.w["scale"].item() if "scale" in weights.w else None
+ self.dora_scale = weights.w.get("dora_scale", None)
+ self.dora_norm_dims = len(self.shape) - 1
+
+ def multiplier(self):
+ unet_multiplier = 3 * [self.network.unet_multiplier] if not isinstance(self.network.unet_multiplier, list) else self.network.unet_multiplier
+ if 'transformer' in self.sd_key[:20]:
+ return self.network.te_multiplier
+ if "down_blocks" in self.sd_key:
+ return unet_multiplier[0]
+ if "mid_block" in self.sd_key:
+ return unet_multiplier[1]
+ if "up_blocks" in self.sd_key:
+ return unet_multiplier[2]
+ else:
+ return unet_multiplier[0]
+
+ def calc_scale(self):
+ if self.scale is not None:
+ return self.scale
+ if self.dim is not None and self.alpha is not None:
+ return self.alpha / self.dim
+ return 1.0
+
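+ # weight decomposition in the DoRA style (inferred from the dora_scale key): merge orig+updown,
+ # normalise each column by its norm, rescale by dora_scale, and return the delta vs the original weight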
+ def apply_weight_decompose(self, updown, orig_weight):
+ # Match the device/dtype
+ orig_weight = orig_weight.to(updown.dtype)
+ dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
+ updown = updown.to(orig_weight.device)
+
+ merged_scale1 = updown + orig_weight
+ merged_scale1_norm = (
+ merged_scale1.transpose(0, 1)
+ .reshape(merged_scale1.shape[1], -1)
+ .norm(dim=1, keepdim=True)
+ .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
+ .transpose(0, 1)
+ )
+
+ dora_merged = (
+ merged_scale1 * (dora_scale / merged_scale1_norm)
+ )
+ final_updown = dora_merged - orig_weight
+ return final_updown
+
+ def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
+ if self.bias is not None:
+ updown = updown.reshape(self.bias.shape)
+ updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown = updown.reshape(output_shape)
+ if len(output_shape) == 4:
+ updown = updown.reshape(output_shape)
+ if orig_weight.size().numel() == updown.size().numel():
+ updown = updown.reshape(orig_weight.shape)
+ if ex_bias is not None:
+ ex_bias = ex_bias * self.multiplier()
+ if self.dora_scale is not None:
+ updown = self.apply_weight_decompose(updown, orig_weight)
+ return updown * self.calc_scale() * self.multiplier(), ex_bias
+
+ def calc_updown(self, target):
+ raise NotImplementedError
+
+ def forward(self, x, y):
+ raise NotImplementedError
diff --git a/modules/lora/network_full.py b/modules/lora/network_full.py
new file mode 100644
index 000000000..5eb0b2e4e
--- /dev/null
+++ b/modules/lora/network_full.py
@@ -0,0 +1,26 @@
+import modules.lora.network as network
+
+
+class ModuleTypeFull(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["diff"]):
+ return NetworkModuleFull(net, weights)
+ return None
+
+
+class NetworkModuleFull(network.NetworkModule): # pylint: disable=abstract-method
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+
+ self.weight = weights.w.get("diff")
+ self.ex_bias = weights.w.get("diff_b")
+
+ def calc_updown(self, target):
+ output_shape = self.weight.shape
+ updown = self.weight.to(target.device, dtype=target.dtype)
+ if self.ex_bias is not None:
+ ex_bias = self.ex_bias.to(target.device, dtype=target.dtype)
+ else:
+ ex_bias = None
+
+ return self.finalize_updown(updown, target, output_shape, ex_bias)
diff --git a/modules/lora/network_glora.py b/modules/lora/network_glora.py
new file mode 100644
index 000000000..ffcb25986
--- /dev/null
+++ b/modules/lora/network_glora.py
@@ -0,0 +1,30 @@
+import modules.lora.network as network
+
+
+class ModuleTypeGLora(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]):
+ return NetworkModuleGLora(net, weights)
+ return None
+
+# adapted from https://github.com/KohakuBlueleaf/LyCORIS
+class NetworkModuleGLora(network.NetworkModule): # pylint: disable=abstract-method
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+
+ if hasattr(self.sd_module, 'weight'):
+ self.shape = self.sd_module.weight.shape
+
+ self.w1a = weights.w["a1.weight"]
+ self.w1b = weights.w["b1.weight"]
+ self.w2a = weights.w["a2.weight"]
+ self.w2b = weights.w["b2.weight"]
+
+ def calc_updown(self, target): # pylint: disable=arguments-differ
+ w1a = self.w1a.to(target.device, dtype=target.dtype)
+ w1b = self.w1b.to(target.device, dtype=target.dtype)
+ w2a = self.w2a.to(target.device, dtype=target.dtype)
+ w2b = self.w2b.to(target.device, dtype=target.dtype)
+ output_shape = [w1a.size(0), w1b.size(1)]
+ updown = (w2b @ w1b) + ((target @ w2a) @ w1a)
+ return self.finalize_updown(updown, target, output_shape)
diff --git a/modules/lora/network_hada.py b/modules/lora/network_hada.py
new file mode 100644
index 000000000..6fc142b3b
--- /dev/null
+++ b/modules/lora/network_hada.py
@@ -0,0 +1,46 @@
+import modules.lora.lyco_helpers as lyco_helpers
+import modules.lora.network as network
+
+
+class ModuleTypeHada(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]):
+ return NetworkModuleHada(net, weights)
+ return None
+
+
+class NetworkModuleHada(network.NetworkModule): # pylint: disable=abstract-method
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+ if hasattr(self.sd_module, 'weight'):
+ self.shape = self.sd_module.weight.shape
+ self.w1a = weights.w["hada_w1_a"]
+ self.w1b = weights.w["hada_w1_b"]
+ self.dim = self.w1b.shape[0]
+ self.w2a = weights.w["hada_w2_a"]
+ self.w2b = weights.w["hada_w2_b"]
+ self.t1 = weights.w.get("hada_t1")
+ self.t2 = weights.w.get("hada_t2")
+
+ def calc_updown(self, target):
+ w1a = self.w1a.to(target.device, dtype=target.dtype)
+ w1b = self.w1b.to(target.device, dtype=target.dtype)
+ w2a = self.w2a.to(target.device, dtype=target.dtype)
+ w2b = self.w2b.to(target.device, dtype=target.dtype)
+ output_shape = [w1a.size(0), w1b.size(1)]
+ if self.t1 is not None:
+ output_shape = [w1a.size(1), w1b.size(1)]
+ t1 = self.t1.to(target.device, dtype=target.dtype)
+ updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
+ output_shape += t1.shape[2:]
+ else:
+ if len(w1b.shape) == 4:
+ output_shape += w1b.shape[2:]
+ updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
+ if self.t2 is not None:
+ t2 = self.t2.to(target.device, dtype=target.dtype)
+ updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
+ else:
+ updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
+ updown = updown1 * updown2
+ return self.finalize_updown(updown, target, output_shape)
diff --git a/modules/lora/network_ia3.py b/modules/lora/network_ia3.py
new file mode 100644
index 000000000..479e42526
--- /dev/null
+++ b/modules/lora/network_ia3.py
@@ -0,0 +1,24 @@
+import modules.lora.network as network
+
+class ModuleTypeIa3(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["weight"]):
+ return NetworkModuleIa3(net, weights)
+ return None
+
+
+class NetworkModuleIa3(network.NetworkModule): # pylint: disable=abstract-method
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+ self.w = weights.w["weight"]
+ self.on_input = weights.w["on_input"].item()
+
+ def calc_updown(self, target):
+ w = self.w.to(target.device, dtype=target.dtype)
+ output_shape = [w.size(0), target.size(1)]
+ if self.on_input:
+ output_shape.reverse()
+ else:
+ w = w.reshape(-1, 1)
+ updown = target * w
+ return self.finalize_updown(updown, target, output_shape)
diff --git a/modules/lora/network_lokr.py b/modules/lora/network_lokr.py
new file mode 100644
index 000000000..877d4005b
--- /dev/null
+++ b/modules/lora/network_lokr.py
@@ -0,0 +1,57 @@
+import torch
+import modules.lora.lyco_helpers as lyco_helpers
+import modules.lora.network as network
+
+
+class ModuleTypeLokr(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w)
+ has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w)
+ if has_1 and has_2:
+ return NetworkModuleLokr(net, weights)
+ return None
+
+
+def make_kron(orig_shape, w1, w2):
+ if len(w2.shape) == 4:
+ w1 = w1.unsqueeze(2).unsqueeze(2)
+ w2 = w2.contiguous()
+ return torch.kron(w1, w2).reshape(orig_shape)
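+# note: LoKr rebuilds the full delta weight as a Kronecker product of two much smaller factors,
+# so w1 and w2 together cover the original weight shape at a fraction of the parameters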
+
+
+class NetworkModuleLokr(network.NetworkModule): # pylint: disable=abstract-method
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+ self.w1 = weights.w.get("lokr_w1")
+ self.w1a = weights.w.get("lokr_w1_a")
+ self.w1b = weights.w.get("lokr_w1_b")
+ self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim
+ self.w2 = weights.w.get("lokr_w2")
+ self.w2a = weights.w.get("lokr_w2_a")
+ self.w2b = weights.w.get("lokr_w2_b")
+ self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim
+ self.t2 = weights.w.get("lokr_t2")
+
+ def calc_updown(self, target):
+ if self.w1 is not None:
+ w1 = self.w1.to(target.device, dtype=target.dtype)
+ else:
+ w1a = self.w1a.to(target.device, dtype=target.dtype)
+ w1b = self.w1b.to(target.device, dtype=target.dtype)
+ w1 = w1a @ w1b
+ if self.w2 is not None:
+ w2 = self.w2.to(target.device, dtype=target.dtype)
+ elif self.t2 is None:
+ w2a = self.w2a.to(target.device, dtype=target.dtype)
+ w2b = self.w2b.to(target.device, dtype=target.dtype)
+ w2 = w2a @ w2b
+ else:
+ t2 = self.t2.to(target.device, dtype=target.dtype)
+ w2a = self.w2a.to(target.device, dtype=target.dtype)
+ w2b = self.w2b.to(target.device, dtype=target.dtype)
+ w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
+ output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
+ if len(target.shape) == 4:
+ output_shape = target.shape
+ updown = make_kron(output_shape, w1, w2)
+ return self.finalize_updown(updown, target, output_shape)
diff --git a/modules/lora/network_lora.py b/modules/lora/network_lora.py
new file mode 100644
index 000000000..3604e059d
--- /dev/null
+++ b/modules/lora/network_lora.py
@@ -0,0 +1,75 @@
+import torch
+import diffusers.models.lora as diffusers_lora
+import modules.lora.lyco_helpers as lyco_helpers
+import modules.lora.network as network
+from modules import devices
+
+
+class ModuleTypeLora(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
+ return NetworkModuleLora(net, weights)
+ return None
+
+
+class NetworkModuleLora(network.NetworkModule):
+
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+ self.up_model = self.create_module(weights.w, "lora_up.weight")
+ self.down_model = self.create_module(weights.w, "lora_down.weight")
+ self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True)
+ self.dim = weights.w["lora_down.weight"].shape[0]
+
+ def create_module(self, weights, key, none_ok=False):
+ weight = weights.get(key)
+ if weight is None and none_ok:
+ return None
+ linear_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear]
+ is_linear = type(self.sd_module) in linear_modules or self.sd_module.__class__.__name__ in {"NNCFLinear", "QLinear", "Linear4bit"}
+ is_conv = type(self.sd_module) in [torch.nn.Conv2d, diffusers_lora.LoRACompatibleConv] or self.sd_module.__class__.__name__ in {"NNCFConv2d", "QConv2d"}
+ if is_linear:
+ weight = weight.reshape(weight.shape[0], -1)
+ module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
+ elif is_conv and (key == "lora_down.weight" or key == "dyn_up"):
+ if len(weight.shape) == 2:
+ weight = weight.reshape(weight.shape[0], -1, 1, 1)
+ if weight.shape[2] != 1 or weight.shape[3] != 1:
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
+ else:
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+ elif is_conv and key == "lora_mid.weight":
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
+ elif is_conv and (key == "lora_up.weight" or key == "dyn_down"):
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+ else:
+ raise AssertionError(f'Lora unsupported: layer={self.network_key} type={type(self.sd_module).__name__}')
+ with torch.no_grad():
+ if weight.shape != module.weight.shape:
+ weight = weight.reshape(module.weight.shape)
+ module.weight.copy_(weight)
+ module.weight.requires_grad_(False)
+ return module
+
+ def calc_updown(self, target): # pylint: disable=W0237
+ target_dtype = target.dtype if target.dtype != torch.uint8 else self.up_model.weight.dtype
+ up = self.up_model.weight.to(target.device, dtype=target_dtype)
+ down = self.down_model.weight.to(target.device, dtype=target_dtype)
+ output_shape = [up.size(0), down.size(1)]
+ if self.mid_model is not None:
+ # cp-decomposition
+ mid = self.mid_model.weight.to(target.device, dtype=target_dtype)
+ updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
+ output_shape += mid.shape[2:]
+ else:
+ if len(down.shape) == 4:
+ output_shape += down.shape[2:]
+ updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)
+ return self.finalize_updown(updown, target, output_shape)
+
+ def forward(self, x, y):
+ self.up_model.to(device=devices.device)
+ self.down_model.to(device=devices.device)
+ if hasattr(y, "scale"):
+ return y(scale=1) + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
+ return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
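+# note: calc_updown() produces a weight delta for the merge path in networks.network_activate, while
+# forward() is the on-the-fly path that adds up(down(x)) scaled by multiplier() and calc_scale()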
diff --git a/modules/lora/network_norm.py b/modules/lora/network_norm.py
new file mode 100644
index 000000000..5d059e92e
--- /dev/null
+++ b/modules/lora/network_norm.py
@@ -0,0 +1,24 @@
+import modules.lora.network as network
+
+
+class ModuleTypeNorm(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["w_norm", "b_norm"]):
+ return NetworkModuleNorm(net, weights)
+ return None
+
+
+class NetworkModuleNorm(network.NetworkModule): # pylint: disable=abstract-method
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+ self.w_norm = weights.w.get("w_norm")
+ self.b_norm = weights.w.get("b_norm")
+
+ def calc_updown(self, target):
+ output_shape = self.w_norm.shape
+ updown = self.w_norm.to(target.device, dtype=target.dtype)
+ if self.b_norm is not None:
+ ex_bias = self.b_norm.to(target.device, dtype=target.dtype)
+ else:
+ ex_bias = None
+ return self.finalize_updown(updown, target, output_shape, ex_bias)
diff --git a/modules/lora/network_oft.py b/modules/lora/network_oft.py
new file mode 100644
index 000000000..e2e61ad45
--- /dev/null
+++ b/modules/lora/network_oft.py
@@ -0,0 +1,82 @@
+import torch
+from einops import rearrange
+import modules.lora.network as network
+from modules.lora.lyco_helpers import factorization
+
+
+class ModuleTypeOFT(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]):
+ return NetworkModuleOFT(net, weights)
+ return None
+
+
+# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
+# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
+class NetworkModuleOFT(network.NetworkModule): # pylint: disable=abstract-method
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+ self.lin_module = None
+ self.org_module: list[torch.nn.Module] = [self.sd_module]
+ self.scale = 1.0
+
+ # kohya-ss
+ if "oft_blocks" in weights.w.keys():
+ self.is_kohya = True
+ self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
+ self.alpha = weights.w["alpha"] # alpha is constraint
+ self.dim = self.oft_blocks.shape[0] # lora dim
+ # LyCORIS
+ elif "oft_diag" in weights.w.keys():
+ self.is_kohya = False
+ self.oft_blocks = weights.w["oft_diag"]
+ # self.alpha is unused
+ self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
+
+ is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
+ is_conv = type(self.sd_module) in [torch.nn.Conv2d]
+ is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
+
+ if is_linear:
+ self.out_dim = self.sd_module.out_features
+ elif is_conv:
+ self.out_dim = self.sd_module.out_channels
+ elif is_other_linear:
+ self.out_dim = self.sd_module.embed_dim
+
+ if self.is_kohya:
+ self.constraint = self.alpha * self.out_dim
+ self.num_blocks = self.dim
+ self.block_size = self.out_dim // self.dim
+ else:
+ self.constraint = None
+ self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
+
+ def calc_updown(self, target):
+ oft_blocks = self.oft_blocks.to(target.device, dtype=target.dtype)
+ eye = torch.eye(self.block_size, device=target.device)
+ constraint = self.constraint.to(target.device) if self.is_kohya else None # constraint only exists for the kohya-ss variant
+
+ if self.is_kohya:
+ block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
+ norm_Q = torch.norm(block_Q.flatten()).to(target.device)
+ new_norm_Q = torch.clamp(norm_Q, max=constraint)
+ block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
+ mat1 = eye + block_Q
+ mat2 = (eye - block_Q).float().inverse()
+ oft_blocks = torch.matmul(mat1, mat2)
+
+ R = oft_blocks.to(target.device, dtype=target.dtype)
+
+ # This errors out for MultiheadAttention, might need to be handled up-stream
+ merged_weight = rearrange(target, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
+ merged_weight = torch.einsum(
+ 'k n m, k n ... -> k m ...',
+ R,
+ merged_weight
+ )
+ merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
+
+ updown = merged_weight.to(target.device, dtype=target.dtype) - target
+ output_shape = target.shape
+ return self.finalize_updown(updown, target, output_shape)
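+# note: unlike additive LoRA variants, OFT applies a (block-diagonal) rotation R to the original
+# weight, so the returned updown is R·W - W and plugs into the same finalize/apply path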
diff --git a/modules/lora/network_overrides.py b/modules/lora/network_overrides.py
new file mode 100644
index 000000000..5334f3c1b
--- /dev/null
+++ b/modules/lora/network_overrides.py
@@ -0,0 +1,49 @@
+from modules import shared
+
+
+maybe_diffusers = [ # forced if lora_maybe_diffusers is enabled
+ 'aaebf6360f7d', # sd15-lcm
+ '3d18b05e4f56', # sdxl-lcm
+ 'b71dcb732467', # sdxl-tcd
+ '813ea5fb1c67', # sdxl-turbo
+ # not really needed, but just in case
+ '5a48ac366664', # hyper-sd15-1step
+ 'ee0ff23dcc42', # hyper-sd15-2step
+ 'e476eb1da5df', # hyper-sd15-4step
+ 'ecb844c3f3b0', # hyper-sd15-8step
+ '1ab289133ebb', # hyper-sd15-8step-cfg
+ '4f494295edb1', # hyper-sdxl-8step
+ 'ca14a8c621f8', # hyper-sdxl-8step-cfg
+ '1c88f7295856', # hyper-sdxl-4step
+ 'fdd5dcd1d88a', # hyper-sdxl-2step
+ '8cca3706050b', # hyper-sdxl-1step
+]
+
+force_diffusers = [ # forced always
+ '816d0eed49fd', # flash-sdxl
+ 'c2ec22757b46', # flash-sd15
+]
+
+force_models = [ # forced always
+ 'sc',
+ # 'sd3',
+ 'kandinsky',
+ 'hunyuandit',
+ 'auraflow',
+]
+
+force_classes = [ # forced always
+]
+
+
+def check_override(shorthash=''):
+ force = False
+ force = force or (shared.sd_model_type in force_models)
+ force = force or (shared.sd_model.__class__.__name__ in force_classes)
+ if len(shorthash) < 4:
+ return force
+ force = force or (any(x.startswith(shorthash) for x in maybe_diffusers) if shared.opts.lora_maybe_diffusers else False)
+ force = force or any(x.startswith(shorthash) for x in force_diffusers)
+ if force and shared.opts.lora_maybe_diffusers:
+ shared.log.debug('LoRA override: force diffusers')
+ return force
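+# note: returns True when LoRA loading should go through the diffusers loader instead of the native one,
+# either for the known adapter hashes listed above or for the model families/classes in force_models and force_classes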
diff --git a/modules/lora/networks.py b/modules/lora/networks.py
new file mode 100644
index 000000000..fc90ddd2d
--- /dev/null
+++ b/modules/lora/networks.py
@@ -0,0 +1,601 @@
+from typing import Union, List
+from contextlib import nullcontext
+import os
+import re
+import time
+import concurrent
+import torch
+import diffusers.models.lora
+import rich.progress as rp
+
+from modules.lora import lora_timers, network, lora_convert, network_overrides
+from modules.lora import network_lora, network_hada, network_ia3, network_oft, network_lokr, network_full, network_norm, network_glora
+from modules.lora.extra_networks_lora import ExtraNetworkLora
+from modules import shared, devices, sd_models, sd_models_compile, errors, files_cache, model_quant
+
+
+debug = os.environ.get('SD_LORA_DEBUG', None) is not None
+extra_network_lora = ExtraNetworkLora()
+available_networks = {}
+available_network_aliases = {}
+loaded_networks: List[network.Network] = []
+applied_layers: list[str] = []
+bnb = None
+lora_cache = {}
+diffuser_loaded = []
+diffuser_scales = []
+available_network_hash_lookup = {}
+forbidden_network_aliases = {}
+re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
+timer = lora_timers.Timer()
+module_types = [
+ network_lora.ModuleTypeLora(),
+ network_hada.ModuleTypeHada(),
+ network_ia3.ModuleTypeIa3(),
+ network_oft.ModuleTypeOFT(),
+ network_lokr.ModuleTypeLokr(),
+ network_full.ModuleTypeFull(),
+ network_norm.ModuleTypeNorm(),
+ network_glora.ModuleTypeGLora(),
+]
+
+# section: load networks from disk
+
+def load_diffusers(name, network_on_disk, lora_scale=shared.opts.extra_networks_default_multiplier) -> Union[network.Network, None]:
+ t0 = time.time()
+ name = name.replace(".", "_")
+ shared.log.debug(f'Load network: type=LoRA name="{name}" file="{network_on_disk.filename}" detected={network_on_disk.sd_version} method=diffusers scale={lora_scale} fuse={shared.opts.lora_fuse_diffusers}')
+ if not shared.native:
+ return None
+ if not hasattr(shared.sd_model, 'load_lora_weights'):
+ shared.log.error(f'Load network: type=LoRA class={shared.sd_model.__class__} does not implement load lora')
+ return None
+ try:
+ shared.sd_model.load_lora_weights(network_on_disk.filename, adapter_name=name)
+ except Exception as e:
+ if 'already in use' in str(e):
+ pass
+ else:
+ if 'The following keys have not been correctly renamed' in str(e):
+ shared.log.error(f'Load network: type=LoRA name="{name}" diffusers unsupported format')
+ else:
+ shared.log.error(f'Load network: type=LoRA name="{name}" {e}')
+ if debug:
+ errors.display(e, "LoRA")
+ return None
+ if name not in diffuser_loaded:
+ diffuser_loaded.append(name)
+ diffuser_scales.append(lora_scale)
+ net = network.Network(name, network_on_disk)
+ net.mtime = os.path.getmtime(network_on_disk.filename)
+ timer.activate += time.time() - t0
+ return net
+
+
+def load_safetensors(name, network_on_disk) -> Union[network.Network, None]:
+ if not shared.sd_loaded:
+ return None
+
+ cached = lora_cache.get(name, None)
+ if debug:
+ shared.log.debug(f'Load network: type=LoRA name="{name}" file="{network_on_disk.filename}" {"cached" if cached else ""}')
+ if cached is not None:
+ return cached
+ net = network.Network(name, network_on_disk)
+ net.mtime = os.path.getmtime(network_on_disk.filename)
+ sd = sd_models.read_state_dict(network_on_disk.filename, what='network')
+ if shared.sd_model_type == 'f1': # if kohya flux lora, convert state_dict
+ sd = lora_convert._convert_kohya_flux_lora_to_diffusers(sd) or sd # pylint: disable=protected-access
+ if shared.sd_model_type == 'sd3': # if kohya sd3 lora, convert state_dict
+ try:
+ sd = lora_convert._convert_kohya_sd3_lora_to_diffusers(sd) or sd # pylint: disable=protected-access
+ except ValueError: # EAFP for diffusers PEFT keys
+ pass
+ lora_convert.assign_network_names_to_compvis_modules(shared.sd_model)
+ keys_failed_to_match = {}
+ matched_networks = {}
+ bundle_embeddings = {}
+ convert = lora_convert.KeyConvert()
+ for key_network, weight in sd.items():
+ parts = key_network.split('.')
+ if parts[0] == "bundle_emb":
+ emb_name, vec_name = parts[1], key_network.split(".", 2)[-1]
+ emb_dict = bundle_embeddings.get(emb_name, {})
+ emb_dict[vec_name] = weight
+ bundle_embeddings[emb_name] = emb_dict
+ continue
+ if len(parts) > 5: # messy handler for diffusers peft lora
+ key_network_without_network_parts = '_'.join(parts[:-2])
+ if not key_network_without_network_parts.startswith('lora_'):
+ key_network_without_network_parts = 'lora_' + key_network_without_network_parts
+ network_part = '.'.join(parts[-2:]).replace('lora_A', 'lora_down').replace('lora_B', 'lora_up')
+ else:
+ key_network_without_network_parts, network_part = key_network.split(".", 1)
+ key, sd_module = convert(key_network_without_network_parts)
+ if sd_module is None:
+ keys_failed_to_match[key_network] = key
+ continue
+ if key not in matched_networks:
+ matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)
+ matched_networks[key].w[network_part] = weight
+ network_types = []
+ for key, weights in matched_networks.items():
+ net_module = None
+ for nettype in module_types:
+ net_module = nettype.create_module(net, weights)
+ if net_module is not None:
+ network_types.append(nettype.__class__.__name__)
+ break
+ if net_module is None:
+ shared.log.error(f'LoRA unhandled: name={name} key={key} weights={weights.w.keys()}')
+ else:
+ net.modules[key] = net_module
+ if len(keys_failed_to_match) > 0:
+ shared.log.warning(f'LoRA name="{name}" type={set(network_types)} unmatched={len(keys_failed_to_match)} matched={len(matched_networks)}')
+ if debug:
+ shared.log.debug(f'LoRA name="{name}" unmatched={keys_failed_to_match}')
+ else:
+ shared.log.debug(f'LoRA name="{name}" type={set(network_types)} keys={len(matched_networks)} direct={shared.opts.lora_fuse_diffusers}')
+ if len(matched_networks) == 0:
+ return None
+ lora_cache[name] = net
+ net.bundle_embeddings = bundle_embeddings
+ return net
+
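+# note: load_safetensors builds Network/NetworkModule objects that are applied manually by
+# network_activate below; load_diffusers above instead defers to the pipeline's load_lora_weights/set_adapters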
+
+def maybe_recompile_model(names, te_multipliers):
+ recompile_model = False
+ if shared.compiled_model_state is not None and shared.compiled_model_state.is_compiled:
+ if len(names) == len(shared.compiled_model_state.lora_model):
+ for i, name in enumerate(names):
+ if shared.compiled_model_state.lora_model[i] != f"{name}:{te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier}":
+ recompile_model = True
+ shared.compiled_model_state.lora_model = []
+ break
+ if not recompile_model:
+ if len(loaded_networks) > 0 and debug:
+ shared.log.debug('Model Compile: Skipping LoRa loading')
+ return recompile_model
+ else:
+ recompile_model = True
+ shared.compiled_model_state.lora_model = []
+ if recompile_model:
+ backup_cuda_compile = shared.opts.cuda_compile
+ sd_models.unload_model_weights(op='model')
+ shared.opts.cuda_compile = []
+ sd_models.reload_model_weights(op='model')
+ shared.opts.cuda_compile = backup_cuda_compile
+ return recompile_model
+
+
+def list_available_networks():
+ t0 = time.time()
+ available_networks.clear()
+ available_network_aliases.clear()
+ forbidden_network_aliases.clear()
+ available_network_hash_lookup.clear()
+ forbidden_network_aliases.update({"none": 1, "Addams": 1})
+ if not os.path.exists(shared.cmd_opts.lora_dir):
+ shared.log.warning(f'LoRA directory not found: path="{shared.cmd_opts.lora_dir}"')
+
+ def add_network(filename):
+ if not os.path.isfile(filename):
+ return
+ name = os.path.splitext(os.path.basename(filename))[0]
+ name = name.replace('.', '_')
+ try:
+ entry = network.NetworkOnDisk(name, filename)
+ available_networks[entry.name] = entry
+ if entry.alias in available_network_aliases:
+ forbidden_network_aliases[entry.alias.lower()] = 1
+ if shared.opts.lora_preferred_name == 'filename':
+ available_network_aliases[entry.name] = entry
+ else:
+ available_network_aliases[entry.alias] = entry
+ if entry.shorthash:
+ available_network_hash_lookup[entry.shorthash] = entry
+ except OSError as e: # should catch FileNotFoundError and PermissionError etc.
+ shared.log.error(f'LoRA: filename="{filename}" {e}')
+
+ candidates = list(files_cache.list_files(shared.cmd_opts.lora_dir, ext_filter=[".pt", ".ckpt", ".safetensors"]))
+ with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
+ for fn in candidates:
+ executor.submit(add_network, fn)
+ t1 = time.time()
+ timer.list = t1 - t0
+ shared.log.info(f'Available LoRAs: path="{shared.cmd_opts.lora_dir}" items={len(available_networks)} folders={len(forbidden_network_aliases)} time={t1 - t0:.2f}')
+
+
+def network_download(name):
+ from huggingface_hub import hf_hub_download
+ if os.path.exists(name):
+ return network.NetworkOnDisk(name, name)
+ parts = name.split('/')
+ if len(parts) >= 5 and parts[1] == 'huggingface.co':
+ repo_id = f'{parts[2]}/{parts[3]}'
+ filename = '/'.join(parts[4:])
+ fn = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=shared.opts.hfcache_dir)
+ return network.NetworkOnDisk(name, fn)
+ return None
+
+
+def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
+ networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names]
+ if any(x is None for x in networks_on_disk):
+ list_available_networks()
+ networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names]
+ for i in range(len(names)):
+ if names[i].startswith('/'):
+ networks_on_disk[i] = network_download(names[i])
+ failed_to_load_networks = []
+ recompile_model = maybe_recompile_model(names, te_multipliers)
+
+ loaded_networks.clear()
+ diffuser_loaded.clear()
+ diffuser_scales.clear()
+ t0 = time.time()
+
+ for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
+ net = None
+ if network_on_disk is not None:
+ shorthash = getattr(network_on_disk, 'shorthash', '').lower()
+ if debug:
+ shared.log.debug(f'Load network: type=LoRA name="{name}" file="{network_on_disk.filename}" hash="{shorthash}"')
+ try:
+ if recompile_model:
+ shared.compiled_model_state.lora_model.append(f"{name}:{te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier}")
+ if shared.opts.lora_force_diffusers or network_overrides.check_override(shorthash): # OpenVINO only works with Diffusers LoRa loading
+ net = load_diffusers(name, network_on_disk, lora_scale=te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier)
+ else:
+ net = load_safetensors(name, network_on_disk)
+ if net is not None:
+ net.mentioned_name = name
+ network_on_disk.read_hash()
+ except Exception as e:
+ shared.log.error(f'Load network: type=LoRA file="{network_on_disk.filename}" {e}')
+ if debug:
+ errors.display(e, 'LoRA')
+ continue
+ if net is None:
+ failed_to_load_networks.append(name)
+ shared.log.error(f'Load network: type=LoRA name="{name}" detected={network_on_disk.sd_version if network_on_disk is not None else None} failed')
+ continue
+ shared.sd_model.embedding_db.load_diffusers_embedding(None, net.bundle_embeddings)
+ net.te_multiplier = te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier
+ net.unet_multiplier = unet_multipliers[i] if unet_multipliers else shared.opts.extra_networks_default_multiplier
+ net.dyn_dim = dyn_dims[i] if dyn_dims else shared.opts.extra_networks_default_multiplier
+ loaded_networks.append(net)
+
+ while len(lora_cache) > shared.opts.lora_in_memory_limit:
+ name = next(iter(lora_cache))
+ lora_cache.pop(name, None)
+
+ if len(diffuser_loaded) > 0:
+ shared.log.debug(f'Load network: type=LoRA loaded={diffuser_loaded} available={shared.sd_model.get_list_adapters()} active={shared.sd_model.get_active_adapters()} scales={diffuser_scales}')
+ try:
+ t0 = time.time()
+ shared.sd_model.set_adapters(adapter_names=diffuser_loaded, adapter_weights=diffuser_scales)
+ if shared.opts.lora_fuse_diffusers:
+ shared.sd_model.fuse_lora(adapter_names=diffuser_loaded, lora_scale=1.0, fuse_unet=True, fuse_text_encoder=True) # fuse uses fixed scale since later apply does the scaling
+ shared.sd_model.unload_lora_weights()
+ timer.activate += time.time() - t0
+ except Exception as e:
+ shared.log.error(f'Load network: type=LoRA {e}')
+ if debug:
+ errors.display(e, 'LoRA')
+
+ if len(loaded_networks) > 0 and debug:
+ shared.log.debug(f'Load network: type=LoRA loaded={len(loaded_networks)} cache={list(lora_cache)}')
+
+ if recompile_model:
+ shared.log.info("Load network: type=LoRA recompiling model")
+ backup_lora_model = shared.compiled_model_state.lora_model
+ if 'Model' in shared.opts.cuda_compile:
+ shared.sd_model = sd_models_compile.compile_diffusers(shared.sd_model)
+
+ shared.compiled_model_state.lora_model = backup_lora_model
+
+ if len(loaded_networks) > 0:
+ devices.torch_gc()
+
+ timer.load = time.time() - t0
+
+
+# section: process loaded networks
+
+def network_backup_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str, wanted_names: tuple):
+ global bnb # pylint: disable=W0603
+ backup_size = 0
+ if len(loaded_networks) > 0 and network_layer_name is not None and any([net.modules.get(network_layer_name, None) for net in loaded_networks]): # noqa: C419 # pylint: disable=R1729
+ t0 = time.time()
+
+ weights_backup = getattr(self, "network_weights_backup", None)
+ if weights_backup is None and wanted_names != (): # pylint: disable=C1803
+ weight = getattr(self, 'weight', None)
+ self.network_weights_backup = None
+ if getattr(weight, "quant_type", None) in ['nf4', 'fp4']:
+ if bnb is None:
+ bnb = model_quant.load_bnb('Load network: type=LoRA', silent=True)
+ if bnb is not None:
+ with devices.inference_context():
+ if shared.opts.lora_fuse_diffusers:
+ self.network_weights_backup = True
+ else:
+ self.network_weights_backup = bnb.functional.dequantize_4bit(weight, quant_state=weight.quant_state, quant_type=weight.quant_type, blocksize=weight.blocksize,)
+ self.quant_state = weight.quant_state
+ self.quant_type = weight.quant_type
+ self.blocksize = weight.blocksize
+ else:
+ if shared.opts.lora_fuse_diffusers:
+ self.network_weights_backup = True
+ else:
+ weights_backup = weight.clone()
+ self.network_weights_backup = weights_backup.to(devices.cpu)
+ else:
+ if shared.opts.lora_fuse_diffusers:
+ self.network_weights_backup = True
+ else:
+ self.network_weights_backup = weight.clone().to(devices.cpu)
+
+ bias_backup = getattr(self, "network_bias_backup", None)
+ if bias_backup is None:
+ if getattr(self, 'bias', None) is not None:
+ if shared.opts.lora_fuse_diffusers:
+ self.network_bias_backup = True
+ else:
+ bias_backup = self.bias.clone()
+ bias_backup = bias_backup.to(devices.cpu)
+ self.network_bias_backup = bias_backup
+
+ if getattr(self, 'network_weights_backup', None) is not None:
+ backup_size += self.network_weights_backup.numel() * self.network_weights_backup.element_size() if isinstance(self.network_weights_backup, torch.Tensor) else 0
+ if getattr(self, 'network_bias_backup', None) is not None:
+ backup_size += self.network_bias_backup.numel() * self.network_bias_backup.element_size() if isinstance(self.network_bias_backup, torch.Tensor) else 0
+ timer.backup += time.time() - t0
+ return backup_size
+
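+# note: with lora_fuse_diffusers the "backup" is just a True flag and weights are patched in place by
+# network_apply_direct; otherwise a CPU copy is stored and restored/patched by network_apply_weights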
+
+def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], network_layer_name: str):
+ if shared.opts.diffusers_offload_mode == "none":
+ try:
+ self.to(devices.device)
+ except Exception:
+ pass
+ batch_updown = None
+ batch_ex_bias = None
+ for net in loaded_networks:
+ module = net.modules.get(network_layer_name, None)
+ if module is None:
+ continue
+ try:
+ t0 = time.time()
+ try:
+ weight = self.weight.to(devices.device)
+ except Exception:
+ weight = self.weight
+ updown, ex_bias = module.calc_updown(weight)
+ if batch_updown is not None and updown is not None:
+ batch_updown += updown.to(batch_updown.device)
+ else:
+ batch_updown = updown
+ if batch_ex_bias is not None and ex_bias is not None:
+ batch_ex_bias += ex_bias.to(batch_ex_bias.device)
+ else:
+ batch_ex_bias = ex_bias
+ timer.calc += time.time() - t0
+ if shared.opts.diffusers_offload_mode == "sequential":
+ t0 = time.time()
+ if batch_updown is not None:
+ batch_updown = batch_updown.to(devices.cpu)
+ if batch_ex_bias is not None:
+ batch_ex_bias = batch_ex_bias.to(devices.cpu)
+ t1 = time.time()
+ timer.move += t1 - t0
+ except RuntimeError as e:
+ extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+ if debug:
+ module_name = net.modules.get(network_layer_name, None)
+ shared.log.error(f'LoRA apply weight name="{net.name}" module="{module_name}" layer="{network_layer_name}" {e}')
+ errors.display(e, 'LoRA')
+ raise RuntimeError('LoRA apply weight') from e
+ continue
+ return batch_updown, batch_ex_bias
+
+
+def network_apply_direct(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, ex_bias: torch.Tensor, deactivate: bool = False):
+ weights_backup = getattr(self, "network_weights_backup", False)
+ bias_backup = getattr(self, "network_bias_backup", False)
+ if not weights_backup and not bias_backup:
+ return None, None
+ t0 = time.time()
+
+ if weights_backup:
+ if updown is not None and len(self.weight.shape) == 4 and self.weight.shape[1] == 9: # inpainting model. zero pad updown to make channel[1] 4 to 9
+ updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) # pylint: disable=not-callable
+ if updown is not None:
+ if deactivate:
+ updown *= -1
+ if getattr(self, "quant_type", None) in ['nf4', 'fp4'] and bnb is not None:
+ try: # TODO lora load: direct with bnb
+ weight = bnb.functional.dequantize_4bit(self.weight, quant_state=self.quant_state, quant_type=self.quant_type, blocksize=self.blocksize)
+ new_weight = weight.to(devices.device) + updown.to(devices.device)
+ self.weight = bnb.nn.Params4bit(new_weight, quant_state=self.quant_state, quant_type=self.quant_type, blocksize=self.blocksize)
+ except Exception:
+ # shared.log.error(f'Load network: type=LoRA quant=bnb type={self.quant_type} state={self.quant_state} blocksize={self.blocksize} {e}')
+ extra_network_lora.errors['bnb'] = extra_network_lora.errors.get('bnb', 0) + 1
+ new_weight = None
+ else:
+ try:
+ new_weight = self.weight.to(devices.device) + updown.to(devices.device)
+ except Exception:
+ new_weight = self.weight + updown
+ self.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+ del new_weight
+ if hasattr(self, "qweight") and hasattr(self, "freeze"):
+ self.freeze()
+
+ if bias_backup:
+ if ex_bias is not None:
+ if deactivate:
+ ex_bias *= -1
+ new_weight = self.bias.to(devices.device) + ex_bias.to(devices.device)
+ self.bias = torch.nn.Parameter(new_weight, requires_grad=False)
+ del new_weight
+
+ timer.apply += time.time() - t0
+ return self.weight.device, self.weight.dtype
+
+
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, diffusers.models.lora.LoRACompatibleLinear, diffusers.models.lora.LoRACompatibleConv], updown: torch.Tensor, ex_bias: torch.Tensor, orig_device: torch.device, deactivate: bool = False):
+ weights_backup = getattr(self, "network_weights_backup", None)
+ bias_backup = getattr(self, "network_bias_backup", None)
+ if weights_backup is None and bias_backup is None:
+ return None, None
+ t0 = time.time()
+
+ if weights_backup is not None:
+ self.weight = None
+ if updown is not None and len(weights_backup.shape) == 4 and weights_backup.shape[1] == 9: # inpainting model. zero pad updown to make channel[1] 4 to 9
+ updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) # pylint: disable=not-callable
+ if updown is not None:
+ if deactivate:
+ updown *= -1
+ new_weight = weights_backup.to(devices.device) + updown.to(devices.device)
+ if getattr(self, "quant_type", None) in ['nf4', 'fp4'] and bnb is not None:
+ self.weight = bnb.nn.Params4bit(new_weight, quant_state=self.quant_state, quant_type=self.quant_type, blocksize=self.blocksize)
+ else:
+ self.weight = torch.nn.Parameter(new_weight.to(device=orig_device), requires_grad=False)
+ del new_weight
+ else:
+ self.weight = torch.nn.Parameter(weights_backup.to(device=orig_device), requires_grad=False)
+ if hasattr(self, "qweight") and hasattr(self, "freeze"):
+ self.freeze()
+
+ if bias_backup is not None:
+ self.bias = None
+ if ex_bias is not None:
+ if deactivate:
+ ex_bias *= -1
+ new_weight = bias_backup.to(devices.device) + ex_bias.to(devices.device)
+ self.bias = torch.nn.Parameter(new_weight.to(device=orig_device), requires_grad=False)
+ del new_weight
+ else:
+ self.bias = torch.nn.Parameter(bias_backup.to(device=orig_device), requires_grad=False)
+
+ timer.apply += time.time() - t0
+ return self.weight.device, self.weight.dtype
+
+
+def network_deactivate():
+ if not shared.opts.lora_fuse_diffusers:
+ return
+ t0 = time.time()
+ timer.clear()
+ sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatibility
+ if shared.opts.diffusers_offload_mode == "sequential":
+ sd_models.disable_offload(sd_model)
+ sd_models.move_model(sd_model, device=devices.cpu)
+ modules = {}
+ for component_name in ['text_encoder', 'text_encoder_2', 'text_encoder_3', 'unet', 'transformer']:
+ component = getattr(sd_model, component_name, None)
+ if component is not None and hasattr(component, 'named_modules'):
+ modules[component_name] = list(component.named_modules())
+ total = sum(len(x) for x in modules.values())
+ if len(loaded_networks) > 0:
+ pbar = rp.Progress(rp.TextColumn('[cyan]Network: type=LoRA action=deactivate'), rp.BarColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
+ task = pbar.add_task(description='', total=total)
+ else:
+ task = None
+ pbar = nullcontext()
+ with devices.inference_context(), pbar:
+ applied_layers.clear()
+ weights_devices = []
+ weights_dtypes = []
+ for component in modules.keys():
+ orig_device = getattr(sd_model, component, None).device
+ for _, module in modules[component]:
+ network_layer_name = getattr(module, 'network_layer_name', None)
+ if shared.state.interrupted or network_layer_name is None:
+ if task is not None:
+ pbar.update(task, advance=1)
+ continue
+ batch_updown, batch_ex_bias = network_calc_weights(module, network_layer_name)
+ if shared.opts.lora_fuse_diffusers:
+ weights_device, weights_dtype = network_apply_direct(module, batch_updown, batch_ex_bias, deactivate=True)
+ else:
+ weights_device, weights_dtype = network_apply_weights(module, batch_updown, batch_ex_bias, orig_device, deactivate=True)
+ weights_devices.append(weights_device)
+ weights_dtypes.append(weights_dtype)
+ if batch_updown is not None or batch_ex_bias is not None:
+ applied_layers.append(network_layer_name)
+ del batch_updown, batch_ex_bias
+ module.network_current_names = ()
+ if task is not None:
+ pbar.update(task, advance=1, description=f'networks={len(loaded_networks)} modules={len(modules)} deactivate={len(applied_layers)}')
+ weights_devices, weights_dtypes = list(set([x for x in weights_devices if x is not None])), list(set([x for x in weights_dtypes if x is not None])) # noqa: C403 # pylint: disable=R1718
+ timer.deactivate = time.time() - t0
+ if debug and len(loaded_networks) > 0:
+ shared.log.debug(f'Deactivate network: type=LoRA networks={len(loaded_networks)} modules={total} deactivate={len(applied_layers)} device={weights_devices} dtype={weights_dtypes} fuse={shared.opts.lora_fuse_diffusers} time={timer.summary}')
+ modules.clear()
+ if shared.opts.diffusers_offload_mode == "sequential":
+ sd_models.set_diffuser_offload(sd_model, op="model")
+
+
+def network_activate(include=[], exclude=[]):
+ t0 = time.time()
+ sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatibility
+ if shared.opts.diffusers_offload_mode == "sequential":
+ sd_models.disable_offload(sd_model)
+ sd_models.move_model(sd_model, device=devices.cpu)
+ modules = {}
+ components = include if len(include) > 0 else ['text_encoder', 'text_encoder_2', 'text_encoder_3', 'unet', 'transformer']
+ components = [x for x in components if x not in exclude]
+ for name in components:
+ component = getattr(sd_model, name, None)
+ if component is not None and hasattr(component, 'named_modules'):
+ modules[name] = list(component.named_modules())
+ total = sum(len(x) for x in modules.values())
+ if len(loaded_networks) > 0:
+ pbar = rp.Progress(rp.TextColumn('[cyan]Network: type=LoRA action=activate'), rp.BarColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
+ task = pbar.add_task(description='', total=total)
+ else:
+ task = None
+ pbar = nullcontext()
+ with devices.inference_context(), pbar:
+ wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks) if len(loaded_networks) > 0 else ()
+ applied_layers.clear()
+ backup_size = 0
+ weights_devices = []
+ weights_dtypes = []
+ for component in modules.keys():
+ orig_device = getattr(sd_model, component, None).device
+ for _, module in modules[component]:
+ network_layer_name = getattr(module, 'network_layer_name', None)
+ current_names = getattr(module, "network_current_names", ())
+ if getattr(module, 'weight', None) is None or shared.state.interrupted or network_layer_name is None or current_names == wanted_names:
+ if task is not None:
+ pbar.update(task, advance=1)
+ continue
+ backup_size += network_backup_weights(module, network_layer_name, wanted_names)
+ batch_updown, batch_ex_bias = network_calc_weights(module, network_layer_name)
+ if shared.opts.lora_fuse_diffusers:
+ weights_device, weights_dtype = network_apply_direct(module, batch_updown, batch_ex_bias)
+ else:
+ weights_device, weights_dtype = network_apply_weights(module, batch_updown, batch_ex_bias, orig_device)
+ weights_devices.append(weights_device)
+ weights_dtypes.append(weights_dtype)
+ if batch_updown is not None or batch_ex_bias is not None:
+ applied_layers.append(network_layer_name)
+ del batch_updown, batch_ex_bias
+ module.network_current_names = wanted_names
+ if task is not None:
+ pbar.update(task, advance=1, description=f'networks={len(loaded_networks)} modules={total} apply={len(applied_layers)} backup={backup_size}')
+ if task is not None and len(applied_layers) == 0:
+ pbar.remove_task(task) # hide progress bar for no action
+ weights_devices, weights_dtypes = list(set([x for x in weights_devices if x is not None])), list(set([x for x in weights_dtypes if x is not None])) # noqa: C403 # pylint: disable=R1718
+ timer.activate += time.time() - t0
+ if debug and len(loaded_networks) > 0:
+ shared.log.debug(f'Load network: type=LoRA networks={len(loaded_networks)} components={components} modules={total} apply={len(applied_layers)} device={weights_devices} dtype={weights_dtypes} backup={backup_size} fuse={shared.opts.lora_fuse_diffusers} time={timer.summary}')
+ modules.clear()
+ if shared.opts.diffusers_offload_mode == "sequential":
+ sd_models.set_diffuser_offload(sd_model, op="model")
diff --git a/modules/meissonic/pipeline_img2img.py b/modules/meissonic/pipeline_img2img.py
index f26af123d..13e5c3717 100644
--- a/modules/meissonic/pipeline_img2img.py
+++ b/modules/meissonic/pipeline_img2img.py
@@ -56,9 +56,6 @@ class Img2ImgPipeline(DiffusionPipeline):
model_cpu_offload_seq = "text_encoder->transformer->vqvae"
- # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before
- # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter
- # off the meta device. There should be a way to fix this instead of just not offloading it
_exclude_from_cpu_offload = ["vqvae"]
def __init__(
diff --git a/modules/meissonic/pipeline_inpaint.py b/modules/meissonic/pipeline_inpaint.py
index 994846fba..d405afa53 100644
--- a/modules/meissonic/pipeline_inpaint.py
+++ b/modules/meissonic/pipeline_inpaint.py
@@ -53,9 +53,6 @@ class InpaintPipeline(DiffusionPipeline):
model_cpu_offload_seq = "text_encoder->transformer->vqvae"
- # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before
- # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter
- # off the meta device. There should be a way to fix this instead of just not offloading it
_exclude_from_cpu_offload = ["vqvae"]
def __init__(
diff --git a/modules/meissonic/transformer.py b/modules/meissonic/transformer.py
index 64f91baa2..43e77ddc7 100644
--- a/modules/meissonic/transformer.py
+++ b/modules/meissonic/transformer.py
@@ -341,11 +341,6 @@ def __call__(
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
- # if image_rotary_emb is not None: # TODO broken import
- # from .embeddings import apply_rotary_emb
- # query = apply_rotary_emb(query, image_rotary_emb)
- # key = apply_rotary_emb(key, image_rotary_emb)
-
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
diff --git a/modules/memmon.py b/modules/memmon.py
index 6887e1e1c..d9fa3963d 100644
--- a/modules/memmon.py
+++ b/modules/memmon.py
@@ -42,14 +42,14 @@ def read(self):
if not self.disabled:
try:
self.data["free"], self.data["total"] = torch.cuda.mem_get_info(self.device.index if self.device.index is not None else torch.cuda.current_device())
+ self.data["used"] = self.data["total"] - self.data["free"]
torch_stats = torch.cuda.memory_stats(self.device)
- self.data["active"] = torch_stats["active.all.current"]
+ self.data["active"] = torch_stats.get("active.all.current", torch_stats["active_bytes.all.current"])
self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
- self.data['retries'] = torch_stats["num_alloc_retries"]
- self.data['oom'] = torch_stats["num_ooms"]
- self.data["used"] = self.data["total"] - self.data["free"]
+ self.data['retries'] = torch_stats.get("num_alloc_retries", -1)
+ self.data['oom'] = torch_stats.get("num_ooms", -1)
except Exception:
self.disabled = True
return self.data
diff --git a/modules/memstats.py b/modules/memstats.py
index c417165a2..d43e5bbfa 100644
--- a/modules/memstats.py
+++ b/modules/memstats.py
@@ -4,13 +4,16 @@
from modules import shared, errors
fail_once = False
+mem = {}
+
+
+def gb(val: float):
+ return round(val / 1024 / 1024 / 1024, 2)
+
def memory_stats():
global fail_once # pylint: disable=global-statement
- def gb(val: float):
- return round(val / 1024 / 1024 / 1024, 2)
-
- mem = {}
+ mem.clear()
try:
process = psutil.Process(os.getpid())
res = process.memory_info()
@@ -19,10 +22,10 @@ def gb(val: float):
mem.update({ 'ram': ram })
except Exception as e:
if not fail_once:
- shared.log.error('Memory stats: {e}')
+ shared.log.error(f'Memory stats: {e}')
errors.display(e, 'Memory stats')
fail_once = True
- mem.update({ 'ram': str(e) })
+ mem.update({ 'ram': { 'error': str(e) } })
try:
s = torch.cuda.mem_get_info()
gpu = { 'used': gb(s[1] - s[0]), 'total': gb(s[1]) }
@@ -38,3 +41,18 @@ def gb(val: float):
except Exception:
pass
return mem
+
+
+def memory_cache():
+ return mem
+
+
+def ram_stats():
+ try:
+ process = psutil.Process(os.getpid())
+ res = process.memory_info()
+ ram_total = 100 * res.rss / process.memory_percent()
+ ram = { 'used': gb(res.rss), 'total': gb(ram_total) }
+ return ram
+ except Exception:
+ return { 'used': 0, 'total': 0 }
diff --git a/modules/merging/merge_methods.py b/modules/merging/merge_methods.py
index 3f704c20f..ce196b60c 100644
--- a/modules/merging/merge_methods.py
+++ b/modules/merging/merge_methods.py
@@ -4,7 +4,7 @@
import torch
from torch import Tensor
-__all__ = [
+__all__ = [ # noqa: RUF022
"weighted_sum",
"weighted_subtraction",
"tensor_sum",
diff --git a/modules/model_flux.py b/modules/model_flux.py
index 17234d9a4..f2286866e 100644
--- a/modules/model_flux.py
+++ b/modules/model_flux.py
@@ -34,7 +34,8 @@ def load_flux_quanto(checkpoint_info):
with torch.device("meta"):
transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype)
quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
- transformer.eval()
+ if shared.opts.diffusers_eval:
+ transformer.eval()
if transformer.dtype != devices.dtype:
try:
transformer = transformer.to(dtype=devices.dtype)
@@ -61,7 +62,8 @@ def load_flux_quanto(checkpoint_info):
with torch.device("meta"):
text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
- text_encoder_2.eval()
+ if shared.opts.diffusers_eval:
+ text_encoder_2.eval()
if text_encoder_2.dtype != devices.dtype:
try:
text_encoder_2 = text_encoder_2.to(dtype=devices.dtype)
@@ -108,6 +110,7 @@ def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unu
return transformer, text_encoder_2
+"""
def quant_flux_bnb(checkpoint_info, transformer, text_encoder_2):
repo_id = sd_models.path_to_repo(checkpoint_info.name)
cache_dir=shared.opts.diffusers_dir
@@ -137,18 +140,36 @@ def quant_flux_bnb(checkpoint_info, transformer, text_encoder_2):
from modules import errors
errors.display(e, 'FLUX:')
return transformer, text_encoder_2
+"""
+def load_quants(kwargs, repo_id, cache_dir):
+ if len(shared.opts.bnb_quantization) > 0:
+ quant_args = {}
+ quant_args = model_quant.create_bnb_config(quant_args)
+ quant_args = model_quant.create_ao_config(quant_args)
+ if not quant_args:
+ return kwargs
+ model_quant.load_bnb(f'Load model: type=FLUX quant={quant_args}')
+ if 'Model' in shared.opts.bnb_quantization and 'transformer' not in kwargs:
+ kwargs['transformer'] = diffusers.FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args)
+ shared.log.debug(f'Quantization: module=transformer type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
+ if 'Text Encoder' in shared.opts.bnb_quantization and 'text_encoder_2' not in kwargs:
+ kwargs['text_encoder_2'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args)
+ shared.log.debug(f'Quantization: module=t5 type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
+ return kwargs
+
+
+"""
def load_flux_gguf(file_path):
transformer = None
- model_te.install_gguf()
+ ggml.install_gguf()
from accelerate import init_empty_weights
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
from modules import ggml, sd_hijack_accelerate
with init_empty_weights():
- from diffusers import FluxTransformer2DModel
- config = FluxTransformer2DModel.load_config(os.path.join('configs', 'flux'), subfolder="transformer")
- transformer = FluxTransformer2DModel.from_config(config).to(devices.dtype)
+ config = diffusers.FluxTransformer2DModel.load_config(os.path.join('configs', 'flux'), subfolder="transformer")
+ transformer = diffusers.FluxTransformer2DModel.from_config(config).to(devices.dtype)
expected_state_dict_keys = list(transformer.state_dict().keys())
state_dict, stats = ggml.load_gguf_state_dict(file_path, devices.dtype)
state_dict = convert_flux_transformer_checkpoint_to_diffusers(state_dict)
@@ -160,9 +181,11 @@ def load_flux_gguf(file_path):
continue
applied += 1
sd_hijack_accelerate.hijack_set_module_tensor_simple(transformer, tensor_name=param_name, value=param, device=0)
+ transformer.gguf = 'gguf'
state_dict[param_name] = None
shared.log.debug(f'Load model: type=Unet/Transformer applied={applied} skipped={skipped} stats={stats}')
return transformer, None
+"""
def load_transformer(file_path): # triggered by opts.sd_unet change
@@ -177,7 +200,9 @@ def load_transformer(file_path): # triggered by opts.sd_unet change
}
shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant={quant} dtype={devices.dtype}')
if 'gguf' in file_path.lower():
- _transformer, _text_encoder_2 = load_flux_gguf(file_path)
+ # _transformer, _text_encoder_2 = load_flux_gguf(file_path)
+ from modules import ggml
+ _transformer = ggml.load_gguf(file_path, cls=diffusers.FluxTransformer2DModel, compute_dtype=devices.dtype)
if _transformer is not None:
transformer = _transformer
elif quant == 'qint8' or quant == 'qint4':
@@ -188,13 +213,14 @@ def load_transformer(file_path): # triggered by opts.sd_unet change
_transformer, _text_encoder_2 = load_flux_bnb(file_path, diffusers_load_config)
if _transformer is not None:
transformer = _transformer
- elif 'nf4' in quant: # TODO right now this is not working for civitai published nf4 models
+ elif 'nf4' in quant: # TODO flux: fix loader for civitai nf4 models
from modules.model_flux_nf4 import load_flux_nf4
_transformer, _text_encoder_2 = load_flux_nf4(file_path)
if _transformer is not None:
transformer = _transformer
else:
diffusers_load_config = model_quant.create_bnb_config(diffusers_load_config)
+ diffusers_load_config = model_quant.create_ao_config(diffusers_load_config)
transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **diffusers_load_config)
if transformer is None:
shared.log.error('Failed to load UNet model')
@@ -223,10 +249,8 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch
if shared.opts.sd_unet != 'None':
try:
debug(f'Load model: type=FLUX unet="{shared.opts.sd_unet}"')
- _transformer = load_transformer(sd_unet.unet_dict[shared.opts.sd_unet])
- if _transformer is not None:
- transformer = _transformer
- else:
+ transformer = load_transformer(sd_unet.unet_dict[shared.opts.sd_unet])
+ if transformer is None:
shared.opts.sd_unet = 'None'
sd_unet.failed_unet.append(shared.opts.sd_unet)
except Exception as e:
@@ -294,7 +318,6 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch
# initialize pipeline with pre-loaded components
kwargs = {}
- # transformer, text_encoder_2 = quant_flux_bnb(checkpoint_info, transformer, text_encoder_2)
if transformer is not None:
kwargs['transformer'] = transformer
sd_unet.loaded_unet = shared.opts.sd_unet
@@ -306,26 +329,46 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch
model_te.loaded_te = shared.opts.sd_text_encoder
if vae is not None:
kwargs['vae'] = vae
- shared.log.debug(f'Load model: type=FLUX preloaded={list(kwargs)}')
if repo_id == 'sayakpaul/flux.1-dev-nf4':
repo_id = 'black-forest-labs/FLUX.1-dev' # workaround since sayakpaul model is missing model_index.json
+ if 'Fill' in repo_id:
+ cls = diffusers.FluxFillPipeline
+ elif 'Canny' in repo_id:
+ cls = diffusers.FluxControlPipeline
+ elif 'Depth' in repo_id:
+ cls = diffusers.FluxControlPipeline
+ else:
+ cls = diffusers.FluxPipeline
+ shared.log.debug(f'Load model: type=FLUX cls={cls.__name__} preloaded={list(kwargs)} revision={diffusers_load_config.get("revision", None)}')
for c in kwargs:
+ if getattr(kwargs[c], 'quantization_method', None) is not None or getattr(kwargs[c], 'gguf', None) is not None:
+ shared.log.debug(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} quant={getattr(kwargs[c], "quantization_method", None) or getattr(kwargs[c], "gguf", None)}')
if kwargs[c].dtype == torch.float32 and devices.dtype != torch.float32:
- shared.log.warning(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype} recast')
- kwargs[c] = kwargs[c].to(dtype=devices.dtype)
+ try:
+ kwargs[c] = kwargs[c].to(dtype=devices.dtype)
+ shared.log.warning(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype} recast')
+ except Exception:
+ pass
- allow_bnb = 'gguf' not in (sd_unet.loaded_unet or '')
- kwargs = model_quant.create_bnb_config(kwargs, allow_bnb)
- if checkpoint_info.path.endswith('.safetensors') and os.path.isfile(checkpoint_info.path):
- pipe = diffusers.FluxPipeline.from_single_file(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
+ allow_quant = 'gguf' not in (sd_unet.loaded_unet or '')
+ fn = checkpoint_info.path
+ if (fn is None) or (not os.path.exists(fn) or os.path.isdir(fn)):
+ # transformer, text_encoder_2 = quant_flux_bnb(checkpoint_info, transformer, text_encoder_2)
+ kwargs = load_quants(kwargs, repo_id, cache_dir=shared.opts.diffusers_dir)
+ kwargs = model_quant.create_bnb_config(kwargs, allow_quant)
+ kwargs = model_quant.create_ao_config(kwargs, allow_quant)
+ if fn.endswith('.safetensors') and os.path.isfile(fn):
+ pipe = diffusers.FluxPipeline.from_single_file(fn, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
else:
- pipe = diffusers.FluxPipeline.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
+ pipe = cls.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
# release memory
transformer = None
text_encoder_1 = None
text_encoder_2 = None
vae = None
+ for k in kwargs.keys():
+ kwargs[k] = None
devices.torch_gc()
return pipe
diff --git a/modules/model_omnigen.py b/modules/model_omnigen.py
index 64c99ddd7..a08ad4ed5 100644
--- a/modules/model_omnigen.py
+++ b/modules/model_omnigen.py
@@ -17,7 +17,8 @@ def load_omnigen(checkpoint_info, diffusers_load_config={}): # pylint: disable=u
pipe.separate_cfg_infer = True
pipe.use_kv_cache = False
pipe.model.to(device=devices.device, dtype=devices.dtype)
- pipe.model.eval()
+ if shared.opts.diffusers_eval:
+ pipe.model.eval()
pipe.vae.to(devices.device, dtype=devices.dtype)
devices.torch_gc()
diff --git a/modules/model_quant.py b/modules/model_quant.py
index 0e7bdd4b3..10e5eb99e 100644
--- a/modules/model_quant.py
+++ b/modules/model_quant.py
@@ -5,6 +5,7 @@
bnb = None
quanto = None
+ao = None
def create_bnb_config(kwargs = None, allow_bnb: bool = True):
@@ -12,6 +13,8 @@ def create_bnb_config(kwargs = None, allow_bnb: bool = True):
if len(shared.opts.bnb_quantization) > 0 and allow_bnb:
if 'Model' in shared.opts.bnb_quantization:
load_bnb()
+ if bnb is None:
+ return kwargs
bnb_config = diffusers.BitsAndBytesConfig(
load_in_8bit=shared.opts.bnb_quantization_type in ['fp8'],
load_in_4bit=shared.opts.bnb_quantization_type in ['nf4', 'fp4'],
@@ -28,11 +31,49 @@ def create_bnb_config(kwargs = None, allow_bnb: bool = True):
return kwargs
+def create_ao_config(kwargs = None, allow_ao: bool = True):
+ from modules import shared
+ if len(shared.opts.torchao_quantization) > 0 and shared.opts.torchao_quantization_mode == 'pre' and allow_ao:
+ if 'Model' in shared.opts.torchao_quantization:
+ load_torchao()
+ if ao is None:
+ return kwargs
+ diffusers.utils.import_utils.is_torchao_available = lambda: True
+ ao_config = diffusers.TorchAoConfig(shared.opts.torchao_quantization_type)
+ shared.log.debug(f'Quantization: module=all type=torchao dtype={shared.opts.torchao_quantization_type}')
+ if kwargs is None:
+ return ao_config
+ else:
+ kwargs['quantization_config'] = ao_config
+ return kwargs
+ return kwargs
+
+
+def load_torchao(msg='', silent=False):
+ global ao # pylint: disable=global-statement
+ if ao is not None:
+ return ao
+ install('torchao==0.7.0', quiet=True)
+ try:
+ import torchao
+ ao = torchao
+ fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
+ log.debug(f'Quantization: type=torchao version={ao.__version__} fn={fn}') # pylint: disable=protected-access
+ return ao
+ except Exception as e:
+ if len(msg) > 0:
+ log.error(f"{msg} failed to import torchao: {e}")
+ ao = None
+ if not silent:
+ raise
+ return None
+
+
def load_bnb(msg='', silent=False):
global bnb # pylint: disable=global-statement
if bnb is not None:
return bnb
- install('bitsandbytes', quiet=True)
+ install('bitsandbytes==0.45.0', quiet=True)
try:
import bitsandbytes
bnb = bitsandbytes
@@ -51,15 +92,18 @@ def load_bnb(msg='', silent=False):
def load_quanto(msg='', silent=False):
+ from modules import shared
global quanto # pylint: disable=global-statement
if quanto is not None:
return quanto
- install('optimum-quanto', quiet=True)
+ install('optimum-quanto==0.2.6', quiet=True)
try:
from optimum import quanto as optimum_quanto # pylint: disable=no-name-in-module
quanto = optimum_quanto
fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
log.debug(f'Quantization: type=quanto version={quanto.__version__} fn={fn}') # pylint: disable=protected-access
+ if shared.opts.diffusers_offload_mode != 'none':
+ shared.log.error(f'Quantization: type=quanto offload={shared.opts.diffusers_offload_mode} not supported')
return quanto
except Exception as e:
if len(msg) > 0:
diff --git a/modules/model_sana.py b/modules/model_sana.py
new file mode 100644
index 000000000..b9f56c7c6
--- /dev/null
+++ b/modules/model_sana.py
@@ -0,0 +1,79 @@
+import os
+import time
+import torch
+import diffusers
+import transformers
+from modules import shared, sd_models, devices, modelloader, model_quant
+
+
+def load_quants(kwargs, repo_id, cache_dir):
+ if len(shared.opts.bnb_quantization) > 0:
+ quant_args = {}
+ quant_args = model_quant.create_bnb_config(quant_args)
+ quant_args = model_quant.create_ao_config(quant_args)
+ load_args = kwargs.copy()
+ if not quant_args:
+ return kwargs
+ model_quant.load_bnb(f'Load model: type=Sana quant={quant_args} args={load_args}')
+ if 'Model' in shared.opts.bnb_quantization and 'transformer' not in kwargs:
+ kwargs['transformer'] = diffusers.models.SanaTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=cache_dir, **load_args, **quant_args)
+ shared.log.debug(f'Quantization: module=transformer type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
+ if 'Text Encoder' in shared.opts.bnb_quantization and 'text_encoder' not in kwargs:
+ kwargs['text_encoder'] = transformers.AutoModelForCausalLM.from_pretrained(repo_id, subfolder="text_encoder", cache_dir=cache_dir, **load_args, **quant_args)
+ shared.log.debug(f'Quantization: module=te type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
+ return kwargs
+
+
+def load_sana(checkpoint_info, kwargs={}):
+ modelloader.hf_login()
+
+ fn = checkpoint_info if isinstance(checkpoint_info, str) else checkpoint_info.path
+ repo_id = sd_models.path_to_repo(fn)
+ kwargs.pop('load_connected_pipeline', None)
+ kwargs.pop('safety_checker', None)
+ kwargs.pop('requires_safety_checker', None)
+ kwargs.pop('torch_dtype', None)
+
+ if not repo_id.endswith('_diffusers'):
+ repo_id = f'{repo_id}_diffusers'
+ if devices.dtype == torch.bfloat16 and 'BF16' not in repo_id:
+ repo_id = repo_id.replace('_diffusers', '_BF16_diffusers')
+
+ if 'Sana_1600M' in repo_id:
+ if devices.dtype == torch.bfloat16 or 'BF16' in repo_id:
+ if 'BF16' not in repo_id:
+ repo_id = repo_id.replace('_diffusers', '_BF16_diffusers')
+ kwargs['variant'] = 'bf16'
+ kwargs['torch_dtype'] = devices.dtype
+ else:
+ kwargs['variant'] = 'fp16'
+ if 'Sana_600M' in repo_id:
+ kwargs['variant'] = 'fp16'
+
+ if (fn is None) or (not os.path.exists(fn) or os.path.isdir(fn)):
+ kwargs = load_quants(kwargs, repo_id, cache_dir=shared.opts.diffusers_dir)
+ # kwargs = model_quant.create_bnb_config(kwargs)
+ # kwargs = model_quant.create_ao_config(kwargs)
+ shared.log.debug(f'Load model: type=Sana repo="{repo_id}" args={list(kwargs)}')
+ t0 = time.time()
+ pipe = diffusers.SanaPipeline.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **kwargs)
+ if devices.dtype == torch.bfloat16 or devices.dtype == torch.float32:
+ if 'transformer' not in kwargs:
+ pipe.transformer = pipe.transformer.to(dtype=devices.dtype)
+ if 'text_encoder' not in kwargs:
+ pipe.text_encoder = pipe.text_encoder.to(dtype=devices.dtype)
+ pipe.vae = pipe.vae.to(dtype=devices.dtype)
+ if devices.dtype == torch.float16:
+ if 'transformer' not in kwargs:
+ pipe.transformer = pipe.transformer.to(dtype=devices.dtype)
+ if 'text_encoder' not in kwargs:
+ pipe.text_encoder = pipe.text_encoder.to(dtype=torch.float32) # gemma2 does not support fp16
+ pipe.vae = pipe.vae.to(dtype=torch.float32) # dc-ae often overflows in fp16
+ if shared.opts.diffusers_eval:
+ pipe.text_encoder.eval()
+ pipe.transformer.eval()
+ t1 = time.time()
+ shared.log.debug(f'Load model: type=Sana target={devices.dtype} te={pipe.text_encoder.dtype} transformer={pipe.transformer.dtype} vae={pipe.vae.dtype} time={t1-t0:.2f}')
+
+ devices.torch_gc()
+ return pipe
diff --git a/modules/model_sd3.py b/modules/model_sd3.py
index b9d579085..1834c573e 100644
--- a/modules/model_sd3.py
+++ b/modules/model_sd3.py
@@ -1,7 +1,7 @@
import os
import diffusers
import transformers
-from modules import shared, devices, sd_models, sd_unet, model_te, model_quant, model_tools
+from modules import shared, devices, sd_models, sd_unet, model_quant, model_tools
def load_overrides(kwargs, cache_dir):
@@ -13,7 +13,9 @@ def load_overrides(kwargs, cache_dir):
sd_unet.loaded_unet = shared.opts.sd_unet
shared.log.debug(f'Load model: type=SD3 unet="{shared.opts.sd_unet}" fmt=safetensors')
elif fn.endswith('.gguf'):
- kwargs = load_gguf(kwargs, fn)
+ from modules import ggml
+ # kwargs = load_gguf(kwargs, fn)
+ kwargs['transformer'] = ggml.load_gguf(fn, cls=diffusers.SD3Transformer2DModel, compute_dtype=devices.dtype)
sd_unet.loaded_unet = shared.opts.sd_unet
shared.log.debug(f'Load model: type=SD3 unet="{shared.opts.sd_unet}" fmt=gguf')
except Exception as e:
@@ -51,19 +53,17 @@ def load_overrides(kwargs, cache_dir):
def load_quants(kwargs, repo_id, cache_dir):
if len(shared.opts.bnb_quantization) > 0:
- model_quant.load_bnb('Load model: type=SD3')
- bnb_config = diffusers.BitsAndBytesConfig(
- load_in_8bit=shared.opts.bnb_quantization_type in ['fp8'],
- load_in_4bit=shared.opts.bnb_quantization_type in ['nf4', 'fp4'],
- bnb_4bit_quant_storage=shared.opts.bnb_quantization_storage,
- bnb_4bit_quant_type=shared.opts.bnb_quantization_type,
- bnb_4bit_compute_dtype=devices.dtype
- )
+ quant_args = {}
+ quant_args = model_quant.create_bnb_config(quant_args)
+ quant_args = model_quant.create_ao_config(quant_args)
+ if not quant_args:
+ return kwargs
+ model_quant.load_bnb(f'Load model: type=SD3 quant={quant_args}')
if 'Model' in shared.opts.bnb_quantization and 'transformer' not in kwargs:
- kwargs['transformer'] = diffusers.SD3Transformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=cache_dir, quantization_config=bnb_config, torch_dtype=devices.dtype)
+ kwargs['transformer'] = diffusers.SD3Transformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args)
shared.log.debug(f'Quantization: module=transformer type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
if 'Text Encoder' in shared.opts.bnb_quantization and 'text_encoder_3' not in kwargs:
- kwargs['text_encoder_3'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_3", variant='fp16', cache_dir=cache_dir, quantization_config=bnb_config, torch_dtype=devices.dtype)
+ kwargs['text_encoder_3'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_3", variant='fp16', cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args)
shared.log.debug(f'Quantization: module=t5 type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
return kwargs
@@ -92,8 +92,9 @@ def load_missing(kwargs, fn, cache_dir):
return kwargs
+"""
def load_gguf(kwargs, fn):
- model_te.install_gguf()
+ ggml.install_gguf()
from accelerate import init_empty_weights
from diffusers.loaders.single_file_utils import convert_sd3_transformer_checkpoint_to_diffusers
from modules import ggml, sd_hijack_accelerate
@@ -110,10 +111,12 @@ def load_gguf(kwargs, fn):
continue
applied += 1
sd_hijack_accelerate.hijack_set_module_tensor_simple(transformer, tensor_name=param_name, value=param, device=0)
+ transformer.gguf = 'gguf'
state_dict[param_name] = None
shared.log.debug(f'Load model: type=Unet/Transformer applied={applied} skipped={skipped} stats={stats} compute={devices.dtype}')
kwargs['transformer'] = transformer
return kwargs
+"""
def load_sd3(checkpoint_info, cache_dir=None, config=None):
@@ -127,7 +130,7 @@ def load_sd3(checkpoint_info, cache_dir=None, config=None):
kwargs = {}
kwargs = load_overrides(kwargs, cache_dir)
- if fn is None or not os.path.exists(fn):
+ if (fn is None) or (not os.path.exists(fn) or os.path.isdir(fn)):
kwargs = load_quants(kwargs, repo_id, cache_dir)
loader = diffusers.StableDiffusion3Pipeline.from_pretrained
@@ -141,7 +144,9 @@ def load_sd3(checkpoint_info, cache_dir=None, config=None):
# kwargs = load_missing(kwargs, fn, cache_dir)
repo_id = fn
elif fn.endswith('.gguf'):
- kwargs = load_gguf(kwargs, fn)
+ from modules import ggml
+ kwargs['transformer'] = ggml.load_gguf(fn, cls=diffusers.SD3Transformer2DModel, compute_dtype=devices.dtype)
+ # kwargs = load_gguf(kwargs, fn)
kwargs = load_missing(kwargs, fn, cache_dir)
kwargs['variant'] = 'fp16'
else:
@@ -150,6 +155,7 @@ def load_sd3(checkpoint_info, cache_dir=None, config=None):
shared.log.debug(f'Load model: type=SD3 kwargs={list(kwargs)} repo="{repo_id}"')
kwargs = model_quant.create_bnb_config(kwargs)
+ kwargs = model_quant.create_ao_config(kwargs)
pipe = loader(
repo_id,
torch_dtype=devices.dtype,
diff --git a/modules/model_stablecascade.py b/modules/model_stablecascade.py
index 6c23ea00a..2a7739e55 100644
--- a/modules/model_stablecascade.py
+++ b/modules/model_stablecascade.py
@@ -187,8 +187,7 @@ def __call__(
callback_on_step_end=None,
callback_on_step_end_tensor_inputs=["latents"],
):
- if shared.opts.diffusers_offload_mode == "balanced":
- shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
+ shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
# 0. Define commonly used variables
self.guidance_scale = guidance_scale
self.do_classifier_free_guidance = self.guidance_scale > 1
@@ -330,14 +329,11 @@ def __call__(
elif output_type == "pil":
images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesnt work
images = self.numpy_to_pil(images)
+ shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
else:
images = latents
- # Offload all models
- if shared.opts.diffusers_offload_mode == "balanced":
- shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
- else:
- self.maybe_free_model_hooks()
+ self.maybe_free_model_hooks()
if not return_dict:
return images
diff --git a/modules/model_te.py b/modules/model_te.py
index 606cdb86d..16bb6d222 100644
--- a/modules/model_te.py
+++ b/modules/model_te.py
@@ -12,20 +12,6 @@
loaded_te = None
-def install_gguf():
- # pip install git+https://github.com/junejae/transformers@feature/t5-gguf
- install('gguf', quiet=True)
- # https://github.com/ggerganov/llama.cpp/issues/9566
- import gguf
- scripts_dir = os.path.join(os.path.dirname(gguf.__file__), '..', 'scripts')
- if os.path.exists(scripts_dir):
- os.rename(scripts_dir, scripts_dir + '_gguf')
- # monkey patch transformers so they detect gguf pacakge correctly
- import importlib
- transformers.utils.import_utils._is_gguf_available = True # pylint: disable=protected-access
- transformers.utils.import_utils._gguf_version = importlib.metadata.version('gguf') # pylint: disable=protected-access
-
-
def load_t5(name=None, cache_dir=None):
global loaded_te # pylint: disable=global-statement
if name is None:
@@ -34,8 +20,9 @@ def load_t5(name=None, cache_dir=None):
modelloader.hf_login()
repo_id = 'stabilityai/stable-diffusion-3-medium-diffusers'
fn = te_dict.get(name) if name in te_dict else None
- if fn is not None and 'gguf' in name.lower():
- install_gguf()
+ if fn is not None and name.lower().endswith('gguf'):
+ from modules import ggml
+ ggml.install_gguf()
with open(os.path.join('configs', 'flux', 'text_encoder_2', 'config.json'), encoding='utf8') as f:
t5_config = transformers.T5Config(**json.load(f))
t5 = transformers.T5EncoderModel.from_pretrained(None, gguf_file=fn, config=t5_config, device_map="auto", cache_dir=cache_dir, torch_dtype=devices.dtype)
@@ -52,7 +39,8 @@ def load_t5(name=None, cache_dir=None):
if torch.is_floating_point(param) and not is_param_float8_e4m3fn:
param = param.to(devices.dtype)
set_module_tensor_to_device(t5, param_name, device=0, value=param)
- t5.eval()
+ if shared.opts.diffusers_eval:
+ t5.eval()
if t5.dtype != devices.dtype:
try:
t5 = t5.to(dtype=devices.dtype)
diff --git a/modules/model_tools.py b/modules/model_tools.py
index 07cd61b6e..fdeda5c2a 100644
--- a/modules/model_tools.py
+++ b/modules/model_tools.py
@@ -69,11 +69,13 @@ def load_modules(repo_id: str, params: dict):
subfolder = 'text_encoder_2'
if cls == transformers.T5EncoderModel: # t5-xxl
subfolder = 'text_encoder_3'
- kwargs['quantization_config'] = model_quant.create_bnb_config()
+ kwargs = model_quant.create_bnb_config(kwargs)
+ kwargs = model_quant.create_ao_config(kwargs)
kwargs['variant'] = 'fp16'
if cls == diffusers.SD3Transformer2DModel:
subfolder = 'transformer'
- kwargs['quantization_config'] = model_quant.create_bnb_config()
+ kwargs = model_quant.create_bnb_config(kwargs)
+ kwargs = model_quant.create_ao_config(kwargs)
if subfolder is None:
continue
shared.log.debug(f'Load: module={name} class={cls.__name__} repo={repo_id} location={subfolder}')
diff --git a/modules/modeldata.py b/modules/modeldata.py
index 604ff4623..4b7ec1776 100644
--- a/modules/modeldata.py
+++ b/modules/modeldata.py
@@ -3,6 +3,45 @@
from modules import shared, errors
+def get_model_type(pipe):
+ name = pipe.__class__.__name__
+ if not shared.native:
+ model_type = 'ldm'
+ elif "StableDiffusion3" in name:
+ model_type = 'sd3'
+ elif "StableDiffusionXL" in name:
+ model_type = 'sdxl'
+ elif "StableDiffusion" in name:
+ model_type = 'sd'
+ elif "LatentConsistencyModel" in name:
+ model_type = 'sd' # lcm is compatible with sd
+ elif "InstaFlowPipeline" in name:
+ model_type = 'sd' # instaflow is compatible with sd
+ elif "AnimateDiffPipeline" in name:
+ model_type = 'sd' # animatediff is compatible with sd
+ elif "Kandinsky" in name:
+ model_type = 'kandinsky'
+ elif "HunyuanDiT" in name:
+ model_type = 'hunyuandit'
+ elif "Cascade" in name:
+ model_type = 'sc'
+ elif "AuraFlow" in name:
+ model_type = 'auraflow'
+ elif "Flux" in name:
+ model_type = 'f1'
+ elif "Lumina" in name:
+ model_type = 'lumina'
+ elif "OmniGen" in name:
+ model_type = 'omnigen'
+ elif "CogVideo" in name:
+ model_type = 'cogvideox'
+ elif "Sana" in name:
+ model_type = 'sana'
+ else:
+ model_type = name
+ return model_type
+
+
class ModelData:
def __init__(self):
self.sd_model = None
@@ -82,36 +121,7 @@ def sd_model_type(self):
if modules.sd_models.model_data.sd_model is None:
model_type = 'none'
return model_type
- if not shared.native:
- model_type = 'ldm'
- elif "StableDiffusion3" in self.sd_model.__class__.__name__:
- model_type = 'sd3'
- elif "StableDiffusionXL" in self.sd_model.__class__.__name__:
- model_type = 'sdxl'
- elif "StableDiffusion" in self.sd_model.__class__.__name__:
- model_type = 'sd'
- elif "LatentConsistencyModel" in self.sd_model.__class__.__name__:
- model_type = 'sd' # lcm is compatible with sd
- elif "InstaFlowPipeline" in self.sd_model.__class__.__name__:
- model_type = 'sd' # instaflow is compatible with sd
- elif "AnimateDiffPipeline" in self.sd_model.__class__.__name__:
- model_type = 'sd' # animatediff is compatible with sd
- elif "Kandinsky" in self.sd_model.__class__.__name__:
- model_type = 'kandinsky'
- elif "HunyuanDiT" in self.sd_model.__class__.__name__:
- model_type = 'hunyuandit'
- elif "Cascade" in self.sd_model.__class__.__name__:
- model_type = 'sc'
- elif "AuraFlow" in self.sd_model.__class__.__name__:
- model_type = 'auraflow'
- elif "Flux" in self.sd_model.__class__.__name__:
- model_type = 'f1'
- elif "OmniGen" in self.sd_model.__class__.__name__:
- model_type = 'omnigen'
- elif "CogVideo" in self.sd_model.__class__.__name__:
- model_type = 'cogvideox'
- else:
- model_type = self.sd_model.__class__.__name__
+ model_type = get_model_type(self.sd_model)
except Exception:
model_type = 'unknown'
return model_type
@@ -123,18 +133,7 @@ def sd_refiner_type(self):
if modules.sd_models.model_data.sd_refiner is None:
model_type = 'none'
return model_type
- if not shared.native:
- model_type = 'ldm'
- elif "StableDiffusion3" in self.sd_refiner.__class__.__name__:
- model_type = 'sd3'
- elif "StableDiffusionXL" in self.sd_refiner.__class__.__name__:
- model_type = 'sdxl'
- elif "StableDiffusion" in self.sd_refiner.__class__.__name__:
- model_type = 'sd'
- elif "Kandinsky" in self.sd_refiner.__class__.__name__:
- model_type = 'kandinsky'
- else:
- model_type = self.sd_refiner.__class__.__name__
+ model_type = get_model_type(self.sd_refiner)
except Exception:
model_type = 'unknown'
return model_type
diff --git a/modules/modelloader.py b/modules/modelloader.py
index ce36a739b..b022b4fc6 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -267,7 +267,7 @@ def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config
def load_diffusers_models(clear=True):
excluded_models = []
- t0 = time.time()
+ # t0 = time.time()
place = shared.opts.diffusers_dir
if place is None or len(place) == 0 or not os.path.isdir(place):
place = os.path.join(models_path, 'Diffusers')
@@ -316,7 +316,7 @@ def load_diffusers_models(clear=True):
debug(f'Error analyzing diffusers model: "{folder}" {e}')
except Exception as e:
shared.log.error(f"Error listing diffusers: {place} {e}")
- shared.log.debug(f'Scanning diffusers cache: folder="{place}" items={len(list(diffuser_repos))} time={time.time()-t0:.2f}')
+ # shared.log.debug(f'Scanning diffusers cache: folder="{place}" items={len(list(diffuser_repos))} time={time.time()-t0:.2f}')
return diffuser_repos
@@ -326,6 +326,9 @@ def find_diffuser(name: str, full=False):
return [repo[0]['name']]
hf_api = hf.HfApi()
models = list(hf_api.list_models(model_name=name, library=['diffusers'], full=True, limit=20, sort="downloads", direction=-1))
+ if len(models) == 0:
+ models = list(hf_api.list_models(model_name=name, full=True, limit=20, sort="downloads", direction=-1)) # widen search
+ models = [m for m in models if m.id.startswith(name)] # filter exact
shared.log.debug(f'Searching diffusers models: {name} {len(models) > 0}')
if len(models) > 0:
if not full:
diff --git a/modules/omnigen/utils.py b/modules/omnigen/utils.py
index 5483d6eab..0304e1732 100644
--- a/modules/omnigen/utils.py
+++ b/modules/omnigen/utils.py
@@ -25,7 +25,6 @@ def update_ema(ema_model, model, decay=0.9999):
"""
ema_params = dict(ema_model.named_parameters())
for name, param in model.named_parameters():
- # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
diff --git a/modules/onnx_impl/pipelines/__init__.py b/modules/onnx_impl/pipelines/__init__.py
index ca1ddd2f7..a11b07fc7 100644
--- a/modules/onnx_impl/pipelines/__init__.py
+++ b/modules/onnx_impl/pipelines/__init__.py
@@ -241,7 +241,7 @@ def run_olive(self, submodels: List[str], in_dir: os.PathLike, out_dir: os.PathL
for i in range(len(flow)):
flow[i] = flow[i].replace("AutoExecutionProvider", shared.opts.onnx_execution_provider)
olive_config["input_model"]["config"]["model_path"] = os.path.abspath(os.path.join(in_dir, submodel, "model.onnx"))
- olive_config["systems"]["local_system"]["config"]["accelerators"][0]["device"] = "cpu" if shared.opts.onnx_execution_provider == ExecutionProvider.CPU else "gpu" # TODO: npu
+ olive_config["systems"]["local_system"]["config"]["accelerators"][0]["device"] = "cpu" if shared.opts.onnx_execution_provider == ExecutionProvider.CPU else "gpu"
olive_config["systems"]["local_system"]["config"]["accelerators"][0]["execution_providers"] = [shared.opts.onnx_execution_provider]
for pass_key in olive_config["passes"]:
diff --git a/modules/pag/__init__.py b/modules/pag/__init__.py
index 29cdee8ca..a72f7825d 100644
--- a/modules/pag/__init__.py
+++ b/modules/pag/__init__.py
@@ -12,33 +12,41 @@ def apply(p: processing.StableDiffusionProcessing): # pylint: disable=arguments-
global orig_pipeline # pylint: disable=global-statement
if not shared.native:
return None
- c = shared.sd_model.__class__ if shared.sd_loaded else None
- if c == StableDiffusionPAGPipeline or c == StableDiffusionXLPAGPipeline:
- unapply()
+ cls = shared.sd_model.__class__ if shared.sd_loaded else None
+ if cls == StableDiffusionPAGPipeline or cls == StableDiffusionXLPAGPipeline:
+ cls = unapply()
if p.pag_scale == 0:
return
- if sd_models.get_diffusers_task(shared.sd_model) != sd_models.DiffusersTaskType.TEXT_2_IMAGE:
- shared.log.warning(f'PAG: pipeline={c} not implemented')
- return None
- if detect.is_sd15(c):
+ if 'PAG' in cls.__name__:
+ pass
+ elif detect.is_sd15(cls):
+ if sd_models.get_diffusers_task(shared.sd_model) != sd_models.DiffusersTaskType.TEXT_2_IMAGE:
+ shared.log.warning(f'PAG: pipeline={cls.__name__} not implemented')
+ return None
orig_pipeline = shared.sd_model
shared.sd_model = sd_models.switch_pipe(StableDiffusionPAGPipeline, shared.sd_model)
- elif detect.is_sdxl(c):
+ elif detect.is_sdxl(cls):
+ if sd_models.get_diffusers_task(shared.sd_model) != sd_models.DiffusersTaskType.TEXT_2_IMAGE:
+ shared.log.warning(f'PAG: pipeline={cls.__name__} not implemented')
+ return None
orig_pipeline = shared.sd_model
shared.sd_model = sd_models.switch_pipe(StableDiffusionXLPAGPipeline, shared.sd_model)
+ elif detect.is_f1(cls):
+ p.task_args['true_cfg_scale'] = p.pag_scale
else:
- shared.log.warning(f'PAG: pipeline={c} required={StableDiffusionPipeline.__name__}')
+ shared.log.warning(f'PAG: pipeline={cls.__name__} required={StableDiffusionPipeline.__name__}')
return None
p.task_args['pag_scale'] = p.pag_scale
p.task_args['pag_adaptive_scaling'] = p.pag_adaptive
+ p.task_args['pag_adaptive_scale'] = p.pag_adaptive
pag_applied_layers = shared.opts.pag_apply_layers
pag_applied_layers_index = pag_applied_layers.split() if len(pag_applied_layers) > 0 else []
pag_applied_layers_index = [p.strip() for p in pag_applied_layers_index]
p.task_args['pag_applied_layers_index'] = pag_applied_layers_index if len(pag_applied_layers_index) > 0 else ['m0'] # Available layers: d[0-5], m[0], u[0-8]
p.extra_generation_params["PAG scale"] = p.pag_scale
p.extra_generation_params["PAG adaptive"] = p.pag_adaptive
- shared.log.debug(f'{c}: args={p.task_args}')
+ # shared.log.debug(f'{c}: args={p.task_args}')
def unapply():
@@ -46,3 +54,4 @@ def unapply():
if orig_pipeline is not None:
shared.sd_model = orig_pipeline
orig_pipeline = None
+ return shared.sd_model.__class__
diff --git a/modules/pag/pipe_sd.py b/modules/pag/pipe_sd.py
index 16f9b5319..11f4fb0cf 100644
--- a/modules/pag/pipe_sd.py
+++ b/modules/pag/pipe_sd.py
@@ -104,7 +104,6 @@ def __call__(
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
@@ -219,7 +218,6 @@ def __call__(
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
diff --git a/modules/pag/pipe_sdxl.py b/modules/pag/pipe_sdxl.py
index 82ae06c07..3a47af3e5 100644
--- a/modules/pag/pipe_sdxl.py
+++ b/modules/pag/pipe_sdxl.py
@@ -124,7 +124,6 @@ def __call__(
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
@@ -239,7 +238,6 @@ def __call__(
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
diff --git a/modules/postprocess/yolo.py b/modules/postprocess/yolo.py
index 0e733d419..8fc203280 100644
--- a/modules/postprocess/yolo.py
+++ b/modules/postprocess/yolo.py
@@ -72,7 +72,7 @@ def predict(
imgsz: int = 640,
half: bool = True,
device = devices.device,
- augment: bool = True,
+ augment: bool = shared.opts.detailer_augment,
agnostic: bool = False,
retina: bool = False,
mask: bool = True,
@@ -300,7 +300,8 @@ def restore(self, np_image, p: processing.StableDiffusionProcessing = None):
# combined.save('/tmp/item.png')
p.image_mask = Image.fromarray(p.image_mask)
- shared.log.debug(f'Detailer processed: models={models_used}')
+ if len(models_used) > 0:
+ shared.log.debug(f'Detailer processed: models={models_used}')
return np_image
def ui(self, tab: str):
diff --git a/modules/processing.py b/modules/processing.py
index 0d557e64e..c85938268 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -118,6 +118,9 @@ def js(self):
def infotext(self, p: StableDiffusionProcessing, index):
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
+ def __str__(self):
+ return f'{self.__class__.__name__}: {self.__dict__}'
+
def process_images(p: StableDiffusionProcessing) -> Processed:
timer.process.reset()
@@ -176,6 +179,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
timer.process.record('pre')
if shared.cmd_opts.profile:
+ timer.startup.profile = True
+ timer.process.profile = True
with context_hypertile_vae(p), context_hypertile_unet(p):
import torch.profiler # pylint: disable=redefined-outer-name
activities=[torch.profiler.ProfilerActivity.CPU]
@@ -286,7 +291,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
t0 = time.time()
if not hasattr(p, 'skip_init'):
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
- extra_network_data = None
debug(f'Processing inner: args={vars(p)}')
for n in range(p.n_iter):
pag.apply(p)
@@ -311,9 +315,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
if len(p.prompts) == 0:
break
- p.prompts, extra_network_data = extra_networks.parse_prompts(p.prompts)
- if not p.disable_extra_networks:
- extra_networks.activate(p, extra_network_data)
+ p.prompts, p.network_data = extra_networks.parse_prompts(p.prompts)
+ if not shared.native:
+ extra_networks.activate(p, p.network_data)
if p.scripts is not None and isinstance(p.scripts, scripts.ScriptRunner):
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
@@ -323,7 +327,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
processed = p.scripts.process_images(p)
if processed is not None:
samples = processed.images
- infotexts = processed.infotexts
+ infotexts += processed.infotexts
if samples is None:
if not shared.native:
from modules.processing_original import process_original
@@ -353,7 +357,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
for i, sample in enumerate(samples):
debug(f'Processing result: index={i+1}/{len(samples)} iteration={n+1}/{p.n_iter}')
p.batch_index = i
- if type(sample) == Image.Image:
+ if isinstance(sample, Image.Image) or (isinstance(sample, list) and isinstance(sample[0], Image.Image)):
image = sample
sample = np.array(sample)
else:
@@ -393,16 +397,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if shared.opts.mask_apply_overlay:
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
- if len(infotexts) > i:
- info = infotexts[i]
+ info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i, all_negative_prompts=p.negative_prompts)
+ infotexts.append(info)
+ if isinstance(image, list):
+ for img in image:
+ img.info["parameters"] = info
+ output_images = image
else:
- info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i, all_negative_prompts=p.negative_prompts)
- infotexts.append(info)
- image.info["parameters"] = info
- output_images.append(image)
+ image.info["parameters"] = info
+ output_images.append(image)
if shared.opts.samples_save and not p.do_not_save_samples and p.outpath_samples is not None:
info = create_infotext(p, p.prompts, p.seeds, p.subseeds, index=i)
- images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=info, p=p) # main save image
+ if isinstance(image, list):
+ for img in image:
+ images.save_image(img, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=info, p=p) # main save image
+ else:
+ images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=info, p=p) # main save image
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([shared.opts.save_mask, shared.opts.save_mask_composite, shared.opts.return_mask, shared.opts.return_mask_composite]):
image_mask = p.mask_for_overlay.convert('RGB')
image1 = image.convert('RGBA').convert('RGBa')
@@ -420,6 +430,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
timer.process.record('post')
del samples
+
+ if not shared.native:
+ extra_networks.deactivate(p, p.network_data)
+
devices.torch_gc()
if hasattr(shared.sd_model, 'restore_pipeline') and shared.sd_model.restore_pipeline is not None:
@@ -446,10 +460,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if shared.native:
from modules import ipadapter
- ipadapter.unapply(shared.sd_model)
-
- if not p.disable_extra_networks:
- extra_networks.deactivate(p, extra_network_data)
+ ipadapter.unapply(shared.sd_model, unload=getattr(p, 'ip_adapter_unload', False))
if shared.opts.include_mask:
if shared.opts.mask_apply_overlay and p.overlay_images is not None and len(p.overlay_images):
@@ -475,5 +486,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None and isinstance(p.scripts, scripts.ScriptRunner) and not (shared.state.interrupted or shared.state.skipped):
p.scripts.postprocess(p, processed)
timer.process.record('post')
- shared.log.info(f'Processed: images={len(output_images)} its={(p.steps * len(output_images)) / (t1 - t0):.2f} time={t1-t0:.2f} timers={timer.process.dct(min_time=0.02)} memory={memstats.memory_stats()}')
+ if not p.disable_extra_networks:
+ shared.log.info(f'Processed: images={len(output_images)} its={(p.steps * len(output_images)) / (t1 - t0):.2f} time={t1-t0:.2f} timers={timer.process.dct()} memory={memstats.memory_stats()}')
+
+ devices.torch_gc(force=True, reason='final')
return processed
diff --git a/modules/processing_args.py b/modules/processing_args.py
index ff766ec04..5ca6cfcd3 100644
--- a/modules/processing_args.py
+++ b/modules/processing_args.py
@@ -6,13 +6,14 @@
import inspect
import torch
import numpy as np
-from modules import shared, errors, sd_models, processing, processing_vae, processing_helpers, sd_hijack_hypertile, prompt_parser_diffusers, timer
+from modules import shared, errors, sd_models, processing, processing_vae, processing_helpers, sd_hijack_hypertile, prompt_parser_diffusers, timer, extra_networks
from modules.processing_callbacks import diffusers_callback_legacy, diffusers_callback, set_callbacks_p
from modules.processing_helpers import resize_hires, fix_prompts, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, get_generator, set_latents, apply_circular # pylint: disable=unused-import
from modules.api import helpers
-debug = shared.log.trace if os.environ.get('SD_DIFFUSERS_DEBUG', None) is not None else lambda *args, **kwargs: None
+debug_enabled = os.environ.get('SD_DIFFUSERS_DEBUG', None)
+debug_log = shared.log.trace if os.environ.get('SD_DIFFUSERS_DEBUG', None) is not None else lambda *args, **kwargs: None
def task_specific_kwargs(p, model):
@@ -22,7 +23,7 @@ def task_specific_kwargs(p, model):
if isinstance(p.init_images[0], str):
p.init_images = [helpers.decode_base64_to_image(i, quiet=True) for i in p.init_images]
p.init_images = [i.convert('RGB') if i.mode != 'RGB' else i for i in p.init_images]
- if sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.TEXT_2_IMAGE or len(getattr(p, 'init_images', [])) == 0 and not is_img2img_model:
+ if (sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.TEXT_2_IMAGE or len(getattr(p, 'init_images', [])) == 0) and not is_img2img_model:
p.ops.append('txt2img')
if hasattr(p, 'width') and hasattr(p, 'height'):
task_args = {
@@ -93,12 +94,14 @@ def task_specific_kwargs(p, model):
'target_subject_category': getattr(p, 'prompt', '').split()[-1],
'output_type': 'pil',
}
- debug(f'Diffusers task specific args: {task_args}')
+ if debug_enabled:
+ debug_log(f'Diffusers task specific args: {task_args}')
return task_args
def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2: typing.Optional[list]=None, negative_prompts_2: typing.Optional[list]=None, desc:str='', **kwargs):
t0 = time.time()
+ shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
apply_circular(p.tiling, model)
if hasattr(model, "set_progress_bar_config"):
model.set_progress_bar_config(bar_format='Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining} ' + '\x1b[38;5;71m' + desc, ncols=80, colour='#327fba')
@@ -108,7 +111,8 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2
signature = inspect.signature(type(model).__call__, follow_wrapped=True)
possible = list(signature.parameters)
- debug(f'Diffusers pipeline possible: {possible}')
+ if debug_enabled:
+ debug_log(f'Diffusers pipeline possible: {possible}')
prompts, negative_prompts, prompts_2, negative_prompts_2 = fix_prompts(prompts, negative_prompts, prompts_2, negative_prompts_2)
steps = kwargs.get("num_inference_steps", None) or len(getattr(p, 'timesteps', ['1']))
clip_skip = kwargs.pop("clip_skip", 1)
@@ -130,12 +134,13 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2
else:
prompt_parser_diffusers.embedder = None
+ extra_networks.activate(p, include=['text_encoder', 'text_encoder_2', 'text_encoder_3'])
if 'prompt' in possible:
if 'OmniGen' in model.__class__.__name__:
prompts = [p.replace('|image|', '<|image_1|>') for p in prompts]
if hasattr(model, 'text_encoder') and hasattr(model, 'tokenizer') and 'prompt_embeds' in possible and prompt_parser_diffusers.embedder is not None:
args['prompt_embeds'] = prompt_parser_diffusers.embedder('prompt_embeds')
- if 'StableCascade' in model.__class__.__name__ and len(getattr(p, 'negative_pooleds', [])) > 0:
+ if 'StableCascade' in model.__class__.__name__ and prompt_parser_diffusers.embedder is not None:
args['prompt_embeds_pooled'] = prompt_parser_diffusers.embedder('positive_pooleds').unsqueeze(0)
elif 'XL' in model.__class__.__name__ and prompt_parser_diffusers.embedder is not None:
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
@@ -159,6 +164,13 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2
args['negative_prompt'] = negative_prompts[0]
else:
args['negative_prompt'] = negative_prompts
+ if 'complex_human_instruction' in possible:
+ chi = any(len(p) < 300 for p in prompts)
+ p.extra_generation_params["CHI"] = chi
+ if not chi:
+ args['complex_human_instruction'] = None
+ if prompt_parser_diffusers.embedder is not None and not prompt_parser_diffusers.embedder.scheduled_prompt: # not scheduled so we don't need it anymore
+ prompt_parser_diffusers.embedder = None
if 'clip_skip' in possible and parser == 'fixed':
if clip_skip == 1:
@@ -181,6 +193,21 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2
shared.log.error(f'Sampler timesteps: {e}')
else:
shared.log.warning(f'Sampler: sampler={model.scheduler.__class__.__name__} timesteps not supported')
+ if 'sigmas' in possible:
+ sigmas = re.split(',| ', shared.opts.schedulers_timesteps)
+ sigmas = [float(x)/1000.0 for x in sigmas if x.isdigit()]
+ if len(sigmas) > 0:
+ if hasattr(model.scheduler, 'set_timesteps') and "sigmas" in set(inspect.signature(model.scheduler.set_timesteps).parameters.keys()):
+ try:
+ args['sigmas'] = sigmas
+ p.steps = len(sigmas)
+ p.timesteps = sigmas
+ steps = p.steps
+ shared.log.debug(f'Sampler: steps={len(sigmas)} sigmas={sigmas}')
+ except Exception as e:
+ shared.log.error(f'Sampler sigmas: {e}')
+ else:
+ shared.log.warning(f'Sampler: sampler={model.scheduler.__class__.__name__} sigmas not supported')
if hasattr(model, 'scheduler') and hasattr(model.scheduler, 'noise_sampler_seed') and hasattr(model.scheduler, 'noise_sampler'):
model.scheduler.noise_sampler = None # noise needs to be reset instead of using cached values
@@ -248,14 +275,16 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2
if arg in possible:
args[arg] = task_kwargs[arg]
task_args = getattr(p, 'task_args', {})
- debug(f'Diffusers task args: {task_args}')
+ if debug_enabled:
+ debug_log(f'Diffusers task args: {task_args}')
for k, v in task_args.items():
if k in possible:
args[k] = v
else:
- debug(f'Diffusers unknown task args: {k}={v}')
+ debug_log(f'Diffusers unknown task args: {k}={v}')
cross_attention_args = getattr(p, 'cross_attention_kwargs', {})
- debug(f'Diffusers cross-attention args: {cross_attention_args}')
+ if debug_enabled:
+ debug_log(f'Diffusers cross-attention args: {cross_attention_args}')
for k, v in cross_attention_args.items():
if args.get('cross_attention_kwargs', None) is None:
args['cross_attention_kwargs'] = {}
@@ -273,7 +302,7 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2
# handle implicit controlnet
if 'control_image' in possible and 'control_image' not in args and 'image' in args:
- debug('Diffusers: set control image')
+ debug_log('Diffusers: set control image')
args['control_image'] = args['image']
sd_hijack_hypertile.hypertile_set(p, hr=len(getattr(p, 'init_images', [])) > 0)
@@ -291,11 +320,14 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2
clean['negative_prompt'] = len(clean['negative_prompt'])
clean.pop('generator', None)
clean['parser'] = parser
- for k, v in clean.items():
+ for k, v in clean.copy().items():
if isinstance(v, torch.Tensor) or isinstance(v, np.ndarray):
clean[k] = v.shape
if isinstance(v, list) and len(v) > 0 and (isinstance(v[0], torch.Tensor) or isinstance(v[0], np.ndarray)):
clean[k] = [x.shape for x in v]
+ if not debug_enabled and k.endswith('_embeds'):
+ del clean[k]
+ clean['prompt'] = 'embeds'
shared.log.debug(f'Diffuser pipeline: {model.__class__.__name__} task={sd_models.get_diffusers_task(model)} batch={p.iteration + 1}/{p.n_iter}x{p.batch_size} set={clean}')
if p.hdr_clamp or p.hdr_maximize or p.hdr_brightness != 0 or p.hdr_color != 0 or p.hdr_sharpen != 0:
@@ -309,5 +341,6 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2
if shared.cmd_opts.profile:
t1 = time.time()
shared.log.debug(f'Profile: pipeline args: {t1-t0:.2f}')
- debug(f'Diffusers pipeline args: {args}')
+ if debug_enabled:
+ debug_log(f'Diffusers pipeline args: {args}')
return args
diff --git a/modules/processing_callbacks.py b/modules/processing_callbacks.py
index 52ea3e575..0b4c7dfe1 100644
--- a/modules/processing_callbacks.py
+++ b/modules/processing_callbacks.py
@@ -5,14 +5,17 @@
import numpy as np
from modules import shared, processing_correction, extra_networks, timer, prompt_parser_diffusers
+
p = None
-debug_callback = shared.log.trace if os.environ.get('SD_CALLBACK_DEBUG', None) is not None else lambda *args, **kwargs: None
+debug = os.environ.get('SD_CALLBACK_DEBUG', None) is not None
+debug_callback = shared.log.trace if debug else lambda *args, **kwargs: None
def set_callbacks_p(processing):
global p # pylint: disable=global-statement
p = processing
+
def prompt_callback(step, kwargs):
if prompt_parser_diffusers.embedder is None or 'prompt_embeds' not in kwargs:
return kwargs
@@ -27,6 +30,7 @@ def prompt_callback(step, kwargs):
debug_callback(f"Callback: {e}")
return kwargs
+
def diffusers_callback_legacy(step: int, timestep: int, latents: typing.Union[torch.FloatTensor, np.ndarray]):
if p is None:
return
@@ -50,7 +54,8 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {}
if p is None:
return kwargs
latents = kwargs.get('latents', None)
- debug_callback(f'Callback: step={step} timestep={timestep} latents={latents.shape if latents is not None else None} kwargs={list(kwargs)}')
+ if debug:
+ debug_callback(f'Callback: step={step} timestep={timestep} latents={latents.shape if latents is not None else None} kwargs={list(kwargs)}')
order = getattr(pipe.scheduler, "order", 1) if hasattr(pipe, 'scheduler') else 1
shared.state.sampling_step = step // order
if shared.state.interrupted or shared.state.skipped:
@@ -61,13 +66,13 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {}
if shared.state.interrupted or shared.state.skipped:
raise AssertionError('Interrupted...')
time.sleep(0.1)
- if hasattr(p, "stepwise_lora"):
- extra_networks.activate(p, p.extra_network_data, step=step)
+ if hasattr(p, "stepwise_lora") and shared.native:
+ extra_networks.activate(p, step=step)
if latents is None:
return kwargs
elif shared.opts.nan_skip:
assert not torch.isnan(latents[..., 0, 0]).all(), f'NaN detected at step {step}: Skipping...'
- if len(getattr(p, 'ip_adapter_names', [])) > 0:
+ if len(getattr(p, 'ip_adapter_names', [])) > 0 and p.ip_adapter_names[0] != 'None':
ip_adapter_scales = list(p.ip_adapter_scales)
ip_adapter_starts = list(p.ip_adapter_starts)
ip_adapter_ends = list(p.ip_adapter_ends)
@@ -78,7 +83,7 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {}
debug_callback(f"Callback: IP Adapter scales={ip_adapter_scales}")
pipe.set_ip_adapter_scale(ip_adapter_scales)
if step != getattr(pipe, 'num_timesteps', 0):
- kwargs = processing_correction.correction_callback(p, timestep, kwargs)
+ kwargs = processing_correction.correction_callback(p, timestep, kwargs, initial=step == 0)
kwargs = prompt_callback(step, kwargs) # monkey patch for diffusers callback issues
if step == int(getattr(pipe, 'num_timesteps', 100) * p.cfg_end) and 'prompt_embeds' in kwargs and 'negative_prompt_embeds' in kwargs:
if "PAG" in shared.sd_model.__class__.__name__:
@@ -105,7 +110,5 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {}
if shared.cmd_opts.profile and shared.profiler is not None:
shared.profiler.step()
t1 = time.time()
- if 'callback' not in timer.process.records:
- timer.process.records['callback'] = 0
- timer.process.records['callback'] += t1 - t0
+ timer.process.add('callback', t1 - t0)
return kwargs
diff --git a/modules/processing_class.py b/modules/processing_class.py
index 79f51576f..97e5f7b6f 100644
--- a/modules/processing_class.py
+++ b/modules/processing_class.py
@@ -31,8 +31,8 @@ def __init__(self,
n_iter: int = 1,
steps: int = 50,
clip_skip: int = 1,
- width: int = 512,
- height: int = 512,
+ width: int = 1024,
+ height: int = 1024,
# samplers
sampler_index: int = None, # pylint: disable=unused-argument # used only to set sampler_name
sampler_name: str = None,
@@ -139,6 +139,7 @@ def __init__(self,
self.negative_pooleds = []
self.disable_extra_networks = False
self.iteration = 0
+ self.network_data = {}
# initializers
self.prompt = prompt
@@ -169,7 +170,7 @@ def __init__(self,
self.image_cfg_scale = image_cfg_scale
self.scale_by = scale_by
self.mask = mask
- self.image_mask = mask # TODO duplciate mask params
+ self.image_mask = mask # TODO processing: remove duplicate mask params
self.latent_mask = latent_mask
self.mask_blur = mask_blur
self.inpainting_fill = inpainting_fill
@@ -338,6 +339,9 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
def close(self):
self.sampler = None # pylint: disable=attribute-defined-outside-init
+ def __str__(self):
+ return f'{self.__class__.__name__}: {self.__dict__}'
+
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def __init__(self, **kwargs):
@@ -347,8 +351,8 @@ def __init__(self, **kwargs):
def init(self, all_prompts=None, all_seeds=None, all_subseeds=None):
if shared.native:
shared.sd_model = sd_models.set_diffuser_pipe(self.sd_model, sd_models.DiffusersTaskType.TEXT_2_IMAGE)
- self.width = self.width or 512
- self.height = self.height or 512
+ self.width = self.width or 1024
+ self.height = self.height or 1024
if all_prompts is not None:
self.all_prompts = all_prompts
if all_seeds is not None:
diff --git a/modules/processing_correction.py b/modules/processing_correction.py
index e715d8c49..050fae889 100644
--- a/modules/processing_correction.py
+++ b/modules/processing_correction.py
@@ -7,9 +7,11 @@
import torch
from modules import shared, sd_vae_taesd, devices
+
debug_enabled = os.environ.get('SD_HDR_DEBUG', None) is not None
debug = shared.log.trace if debug_enabled else lambda *args, **kwargs: None
debug('Trace: HDR')
+skip_correction = False
def sharpen_tensor(tensor, ratio=0):
@@ -116,8 +118,15 @@ def correction(p, timestep, latent):
return latent
-def correction_callback(p, timestep, kwargs):
- if not any([p.hdr_clamp, p.hdr_mode, p.hdr_maximize, p.hdr_sharpen, p.hdr_color, p.hdr_brightness, p.hdr_tint_ratio]):
+def correction_callback(p, timestep, kwargs, initial: bool = False):
+ global skip_correction # pylint: disable=global-statement
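+    # the first callback step (initial=True) decides once whether any HDR correction is enabled and caches the result; later steps only consult the cached flag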
+ if initial:
+ if not any([p.hdr_clamp, p.hdr_mode, p.hdr_maximize, p.hdr_sharpen, p.hdr_color, p.hdr_brightness, p.hdr_tint_ratio]):
+ skip_correction = True
+ return kwargs
+ else:
+ skip_correction = False
+ elif skip_correction:
return kwargs
latents = kwargs["latents"]
if debug_enabled:
diff --git a/modules/processing_diffusers.py b/modules/processing_diffusers.py
index 2164134b1..33d875e95 100644
--- a/modules/processing_diffusers.py
+++ b/modules/processing_diffusers.py
@@ -4,10 +4,12 @@
import numpy as np
import torch
import torchvision.transforms.functional as TF
-from modules import shared, devices, processing, sd_models, errors, sd_hijack_hypertile, processing_vae, sd_models_compile, hidiffusion, timer, modelstats
+from PIL import Image
+from modules import shared, devices, processing, sd_models, errors, sd_hijack_hypertile, processing_vae, sd_models_compile, hidiffusion, timer, modelstats, extra_networks
from modules.processing_helpers import resize_hires, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, save_intermediate, update_sampler, is_txt2img, is_refiner_enabled
from modules.processing_args import set_pipeline_args
from modules.onnx_impl import preprocess_pipeline as preprocess_onnx_pipeline, check_parameters_changed as olive_check_parameters_changed
+from modules.lora import networks
debug = shared.log.trace if os.environ.get('SD_DIFFUSERS_DEBUG', None) is not None else lambda *args, **kwargs: None
@@ -75,7 +77,6 @@ def process_base(p: processing.StableDiffusionProcessing):
clip_skip=p.clip_skip,
desc='Base',
)
- timer.process.record('args')
shared.state.sampling_steps = base_args.get('prior_num_inference_steps', None) or p.steps or base_args.get('num_inference_steps', None)
if shared.opts.scheduler_eta is not None and shared.opts.scheduler_eta > 0 and shared.opts.scheduler_eta < 1:
p.extra_generation_params["Sampler Eta"] = shared.opts.scheduler_eta
@@ -83,11 +84,13 @@ def process_base(p: processing.StableDiffusionProcessing):
try:
t0 = time.time()
sd_models_compile.check_deepcache(enable=True)
+ shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
sd_models.move_model(shared.sd_model, devices.device)
if hasattr(shared.sd_model, 'unet'):
sd_models.move_model(shared.sd_model.unet, devices.device)
if hasattr(shared.sd_model, 'transformer'):
sd_models.move_model(shared.sd_model.transformer, devices.device)
+ extra_networks.activate(p, exclude=['text_encoder', 'text_encoder_2'])
hidiffusion.apply(p, shared.sd_model_type)
# if 'image' in base_args:
# base_args['image'] = set_latents(p)
@@ -199,11 +202,6 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
if hasattr(shared.sd_model, "vae") and output.images is not None and len(output.images) > 0:
output.images = processing_vae.vae_decode(latents=output.images, model=shared.sd_model, full_quality=p.full_quality, output_type='pil', width=p.hr_upscale_to_x, height=p.hr_upscale_to_y) # controlnet cannnot deal with latent input
p.task_args['image'] = output.images # replace so hires uses new output
- sd_models.move_model(shared.sd_model, devices.device)
- if hasattr(shared.sd_model, 'unet'):
- sd_models.move_model(shared.sd_model.unet, devices.device)
- if hasattr(shared.sd_model, 'transformer'):
- sd_models.move_model(shared.sd_model.transformer, devices.device)
update_sampler(p, shared.sd_model, second_pass=True)
orig_denoise = p.denoising_strength
p.denoising_strength = strength
@@ -227,11 +225,20 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
shared.state.job = 'HiRes'
shared.state.sampling_steps = hires_args.get('prior_num_inference_steps', None) or p.steps or hires_args.get('num_inference_steps', None)
try:
+ shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
+ sd_models.move_model(shared.sd_model, devices.device)
+ if hasattr(shared.sd_model, 'unet'):
+ sd_models.move_model(shared.sd_model.unet, devices.device)
+ if hasattr(shared.sd_model, 'transformer'):
+ sd_models.move_model(shared.sd_model.transformer, devices.device)
+ if 'base' in p.skip:
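+            # networks are normally activated during the base pass; if base was skipped, activate them here for the hires pass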
+ extra_networks.activate(p)
sd_models_compile.check_deepcache(enable=True)
output = shared.sd_model(**hires_args) # pylint: disable=not-callable
if isinstance(output, dict):
output = SimpleNamespace(**output)
- shared.history.add(output.images, info=processing.create_infotext(p), ops=p.ops)
+ if hasattr(output, 'images'):
+ shared.history.add(output.images, info=processing.create_infotext(p), ops=p.ops)
sd_models_compile.check_deepcache(enable=False)
sd_models_compile.openvino_post_compile(op="base")
except AssertionError as e:
@@ -263,8 +270,7 @@ def process_refine(p: processing.StableDiffusionProcessing, output):
if shared.state.interrupted or shared.state.skipped:
shared.sd_model = orig_pipeline
return output
- if shared.opts.diffusers_offload_mode == "balanced":
- shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
+ shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
if shared.opts.diffusers_move_refiner:
sd_models.move_model(shared.sd_refiner, devices.device)
if hasattr(shared.sd_refiner, 'unet'):
@@ -313,7 +319,8 @@ def process_refine(p: processing.StableDiffusionProcessing, output):
output = shared.sd_refiner(**refiner_args) # pylint: disable=not-callable
if isinstance(output, dict):
output = SimpleNamespace(**output)
- shared.history.add(output.images, info=processing.create_infotext(p), ops=p.ops)
+ if hasattr(output, 'images'):
+ shared.history.add(output.images, info=processing.create_infotext(p), ops=p.ops)
sd_models_compile.openvino_post_compile(op="refiner")
except AssertionError as e:
shared.log.info(e)
@@ -323,13 +330,6 @@ def process_refine(p: processing.StableDiffusionProcessing, output):
errors.display(e, 'Processing')
modelstats.analyze()
- """ # TODO decode using refiner
- if not shared.state.interrupted and not shared.state.skipped:
- refiner_images = processing_vae.vae_decode(latents=refiner_output.images, model=shared.sd_refiner, full_quality=True, width=max(p.width, p.hr_upscale_to_x), height=max(p.height, p.hr_upscale_to_y))
- for refiner_image in refiner_images:
- results.append(refiner_image)
- """
-
if shared.opts.diffusers_offload_mode == "balanced":
shared.sd_refiner = sd_models.apply_balanced_offload(shared.sd_refiner)
elif shared.opts.diffusers_move_refiner:
@@ -343,35 +343,54 @@ def process_refine(p: processing.StableDiffusionProcessing, output):
def process_decode(p: processing.StableDiffusionProcessing, output):
+ shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model, exclude=['vae'])
if output is not None:
if not hasattr(output, 'images') and hasattr(output, 'frames'):
shared.log.debug(f'Generated: frames={len(output.frames[0])}')
output.images = output.frames[0]
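+        # some pipelines already return decoded PIL images; in that case skip VAE decode entirely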
+ if output.images is not None and len(output.images) > 0 and isinstance(output.images[0], Image.Image):
+ return output.images
model = shared.sd_model if not is_refiner_enabled(p) else shared.sd_refiner
if not hasattr(model, 'vae'):
if hasattr(model, 'pipe') and hasattr(model.pipe, 'vae'):
model = model.pipe
- if hasattr(model, "vae") and output.images is not None and len(output.images) > 0:
+ if (hasattr(model, "vae") or hasattr(model, "vqgan")) and output.images is not None and len(output.images) > 0:
if p.hr_resize_mode > 0 and (p.hr_upscaler != 'None' or p.hr_resize_mode == 5):
width = max(getattr(p, 'width', 0), getattr(p, 'hr_upscale_to_x', 0))
height = max(getattr(p, 'height', 0), getattr(p, 'hr_upscale_to_y', 0))
else:
width = getattr(p, 'width', 0)
height = getattr(p, 'height', 0)
- results = processing_vae.vae_decode(
- latents = output.images,
- model = model,
- full_quality = p.full_quality,
- width = width,
- height = height,
- )
+ frames = p.task_args.get('num_frames', None)
+ if isinstance(output.images, list):
+ results = []
+ for i in range(len(output.images)):
+ result_batch = processing_vae.vae_decode(
+ latents = output.images[i],
+ model = model,
+ full_quality = p.full_quality,
+ width = width,
+ height = height,
+ frames = frames,
+ )
+ for result in list(result_batch):
+ results.append(result)
+ else:
+ results = processing_vae.vae_decode(
+ latents = output.images,
+ model = model,
+ full_quality = p.full_quality,
+ width = width,
+ height = height,
+ frames = frames,
+ )
elif hasattr(output, 'images'):
results = output.images
else:
- shared.log.warning('Processing returned no results')
+ shared.log.warning('Processing: no results')
results = []
else:
- shared.log.warning('Processing returned no results')
+ shared.log.warning('Processing: no results')
results = []
return results
@@ -425,9 +444,11 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
if p.negative_prompts is None or len(p.negative_prompts) == 0:
p.negative_prompts = p.all_negative_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
- sd_models.move_model(shared.sd_model, devices.device)
sd_models_compile.openvino_recompile_model(p, hires=False, refiner=False) # recompile if a parameter changes
+ if hasattr(p, 'dummy'):
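+        # a 'dummy' request short-circuits generation and returns a single blank RGB image at the requested size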
+ images = [Image.new(mode='RGB', size=(p.width, p.height))]
+ return images
if 'base' not in p.skip:
output = process_base(p)
else:
@@ -450,10 +471,16 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
shared.sd_model = orig_pipeline
return results
- results = process_decode(p, output)
+ extra_networks.deactivate(p)
+ timer.process.add('lora', networks.timer.total)
+ networks.timer.clear(complete=True)
+ results = process_decode(p, output)
timer.process.record('decode')
+
shared.sd_model = orig_pipeline
+ # shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
+
if p.state == '':
global last_p # pylint: disable=global-statement
last_p = p
diff --git a/modules/processing_helpers.py b/modules/processing_helpers.py
index ec7fbf048..304e2c211 100644
--- a/modules/processing_helpers.py
+++ b/modules/processing_helpers.py
@@ -1,4 +1,5 @@
import os
+import time
import math
import random
import warnings
@@ -9,7 +10,7 @@
from PIL import Image
from skimage import exposure
from blendmodes.blend import blendLayers, BlendType
-from modules import shared, devices, images, sd_models, sd_samplers, sd_hijack_hypertile, processing_vae
+from modules import shared, devices, images, sd_models, sd_samplers, sd_hijack_hypertile, processing_vae, timer
debug = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None
@@ -158,7 +159,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
# enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101].
- if p is not None and p.sampler is not None and (len(seeds) > 1 and shared.opts.enable_batch_seeds or eta_noise_seed_delta > 0):
+ if p is not None and p.sampler is not None and ((len(seeds) > 1 and shared.opts.enable_batch_seeds) or (eta_noise_seed_delta > 0)):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else:
sampler_noises = None
@@ -352,6 +353,7 @@ def diffusers_image_conditioning(_source_image, latent_image, _image_mask=None):
def validate_sample(tensor):
+ t0 = time.time()
if not isinstance(tensor, np.ndarray) and not isinstance(tensor, torch.Tensor):
return tensor
dtype = tensor.dtype
@@ -366,17 +368,18 @@ def validate_sample(tensor):
sample = 255.0 * np.moveaxis(sample, 0, 2) if not shared.native else 255.0 * sample
with warnings.catch_warnings(record=True) as w:
cast = sample.astype(np.uint8)
- minimum, maximum, mean = np.min(cast), np.max(cast), np.mean(cast)
- if len(w) > 0 or minimum == maximum:
+ if len(w) > 0:
nans = np.isnan(sample).sum()
cast = np.nan_to_num(sample)
cast = cast.astype(np.uint8)
vae = shared.sd_model.vae.dtype if hasattr(shared.sd_model, 'vae') else None
upcast = getattr(shared.sd_model.vae.config, 'force_upcast', None) if hasattr(shared.sd_model, 'vae') and hasattr(shared.sd_model.vae, 'config') else None
- shared.log.error(f'Decode: sample={sample.shape} invalid={nans} mean={mean} dtype={dtype} vae={vae} upcast={upcast} failed to validate')
+ shared.log.error(f'Decode: sample={sample.shape} invalid={nans} dtype={dtype} vae={vae} upcast={upcast} failed to validate')
if upcast is not None and not upcast:
setattr(shared.sd_model.vae.config, 'force_upcast', True) # noqa: B010
shared.log.warning('Decode: upcast=True set, retry operation')
+ t1 = time.time()
+ timer.process.add('validate', t1 - t0)
return cast
@@ -411,7 +414,7 @@ def resize_hires(p, latents): # input=latents output=pil if not latent_upscaler
if latent_upscaler is not None:
return torch.nn.functional.interpolate(latents, size=(p.hr_upscale_to_y // 8, p.hr_upscale_to_x // 8), mode=latent_upscaler["mode"], antialias=latent_upscaler["antialias"])
first_pass_images = processing_vae.vae_decode(latents=latents, model=shared.sd_model, full_quality=p.full_quality, output_type='pil', width=p.width, height=p.height)
- if p.hr_upscale_to_x == 0 or p.hr_upscale_to_y == 0 and hasattr(p, 'init_hr'):
+ if p.hr_upscale_to_x == 0 or (p.hr_upscale_to_y == 0 and hasattr(p, 'init_hr')):
shared.log.error('Hires: missing upscaling dimensions')
return first_pass_images
resized_images = []
@@ -561,7 +564,9 @@ def save_intermediate(p, latents, suffix):
def update_sampler(p, sd_model, second_pass=False):
sampler_selection = p.hr_sampler_name if second_pass else p.sampler_name
if hasattr(sd_model, 'scheduler'):
- if sampler_selection is None or sampler_selection == 'None':
+ if sampler_selection == 'None':
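+            # an explicit 'None' selection keeps whatever scheduler is already attached to the pipeline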
+ return
+ if sampler_selection is None:
sampler = sd_samplers.all_samplers_map.get("UniPC")
else:
sampler = sd_samplers.all_samplers_map.get(sampler_selection, None)
diff --git a/modules/processing_info.py b/modules/processing_info.py
index 714ebf35f..e0fca12ae 100644
--- a/modules/processing_info.py
+++ b/modules/processing_info.py
@@ -140,6 +140,7 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No
if sd_hijack is not None and hasattr(sd_hijack.model_hijack, 'embedding_db') and len(sd_hijack.model_hijack.embedding_db.embeddings_used) > 0: # this is for original hijaacked models only, diffusers are handled separately
args["Embeddings"] = ', '.join(sd_hijack.model_hijack.embedding_db.embeddings_used)
# samplers
+
if getattr(p, 'sampler_name', None) is not None:
args["Sampler eta delta"] = shared.opts.eta_noise_seed_delta if shared.opts.eta_noise_seed_delta != 0 and sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p) else None
args["Sampler eta multiplier"] = p.initial_noise_multiplier if getattr(p, 'initial_noise_multiplier', 1.0) != 1.0 else None
diff --git a/modules/processing_vae.py b/modules/processing_vae.py
index 3c0357c81..95b83daa4 100644
--- a/modules/processing_vae.py
+++ b/modules/processing_vae.py
@@ -33,6 +33,62 @@ def create_latents(image, p, dtype=None, device=None):
return latents
+def full_vqgan_decode(latents, model):
+ t0 = time.time()
+ if model is None or not hasattr(model, 'vqgan'):
+ shared.log.error('VQGAN not found in model')
+ return []
+ if debug:
+ devices.torch_gc(force=True)
+ shared.mem_mon.reset()
+
+ base_device = None
+ if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False):
+ base_device = sd_models.move_base(model, devices.cpu)
+
+ if shared.opts.diffusers_offload_mode == "balanced":
+ shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
+ elif shared.opts.diffusers_offload_mode != "sequential":
+ sd_models.move_model(model.vqgan, devices.device)
+
+ latents = latents.to(devices.device, dtype=model.vqgan.dtype)
+
+    # normalize latents
+ scaling_factor = model.vqgan.config.get("scale_factor", None)
+ if scaling_factor:
+ latents = latents * scaling_factor
+
+ vae_name = os.path.splitext(os.path.basename(sd_vae.loaded_vae_file))[0] if sd_vae.loaded_vae_file is not None else "default"
+ vae_stats = f'name="{vae_name}" dtype={model.vqgan.dtype} device={model.vqgan.device}'
+ latents_stats = f'shape={latents.shape} dtype={latents.dtype} device={latents.device}'
+ stats = f'vae {vae_stats} latents {latents_stats}'
+
+ log_debug(f'VAE config: {model.vqgan.config}')
+ try:
+ decoded = model.vqgan.decode(latents).sample.clamp(0, 1)
+ except Exception as e:
+ shared.log.error(f'VAE decode: {stats} {e}')
+ errors.display(e, 'VAE decode')
+ decoded = []
+
+ # delete vae after OpenVINO compile
+ if 'VAE' in shared.opts.cuda_compile and shared.opts.cuda_compile_backend == "openvino_fx" and shared.compiled_model_state.first_pass_vae:
+ shared.compiled_model_state.first_pass_vae = False
+ if not shared.opts.openvino_disable_memory_cleanup and hasattr(shared.sd_model, "vqgan"):
+ model.vqgan.apply(sd_models.convert_to_faketensors)
+ devices.torch_gc(force=True)
+
+ if shared.opts.diffusers_offload_mode == "balanced":
+ shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
+ elif shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and base_device is not None:
+ sd_models.move_base(model, base_device)
+ t1 = time.time()
+ if debug:
+ log_debug(f'VAE memory: {shared.mem_mon.read()}')
+ shared.log.debug(f'VAE decode: {stats} time={round(t1-t0, 3)}')
+ return decoded
+
+
def full_vae_decode(latents, model):
t0 = time.time()
if not hasattr(model, 'vae') and hasattr(model, 'pipe'):
@@ -48,8 +104,6 @@ def full_vae_decode(latents, model):
if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False):
base_device = sd_models.move_base(model, devices.cpu)
- if shared.opts.diffusers_offload_mode == "balanced":
- shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
elif shared.opts.diffusers_offload_mode != "sequential":
sd_models.move_model(model.vae, devices.device)
@@ -60,10 +114,11 @@ def full_vae_decode(latents, model):
else: # manual upcast and we restore it later
model.vae.orig_dtype = model.vae.dtype
model.vae = model.vae.to(dtype=torch.float32)
- latents = latents.to(torch.float32)
latents = latents.to(devices.device)
if getattr(model.vae, "post_quant_conv", None) is not None:
latents = latents.to(next(iter(model.vae.post_quant_conv.parameters())).dtype)
+ else:
+ latents = latents.to(model.vae.dtype)
# normalize latents
latents_mean = model.vae.config.get("latents_mean", None)
@@ -86,10 +141,12 @@ def full_vae_decode(latents, model):
log_debug(f'VAE config: {model.vae.config}')
try:
- decoded = model.vae.decode(latents, return_dict=False)[0]
+ with devices.inference_context():
+ decoded = model.vae.decode(latents, return_dict=False)[0]
except Exception as e:
shared.log.error(f'VAE decode: {stats} {e}')
- errors.display(e, 'VAE decode')
+ if 'out of memory' not in str(e):
+ errors.display(e, 'VAE decode')
decoded = []
if hasattr(model.vae, "orig_dtype"):
@@ -103,8 +160,6 @@ def full_vae_decode(latents, model):
model.vae.apply(sd_models.convert_to_faketensors)
devices.torch_gc(force=True)
- if shared.opts.diffusers_offload_mode == "balanced":
- shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
elif shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and base_device is not None:
sd_models.move_base(model, base_device)
t1 = time.time()
@@ -147,7 +202,7 @@ def taesd_vae_encode(image):
return encoded
-def vae_decode(latents, model, output_type='np', full_quality=True, width=None, height=None):
+def vae_decode(latents, model, output_type='np', full_quality=True, width=None, height=None, frames=None):
t0 = time.time()
model = model or shared.sd_model
if not hasattr(model, 'vae') and hasattr(model, 'pipe'):
@@ -161,11 +216,15 @@ def vae_decode(latents, model, output_type='np', full_quality=True, width=None,
return []
if shared.state.interrupted or shared.state.skipped:
return []
- if not hasattr(model, 'vae'):
+ if not hasattr(model, 'vae') and not hasattr(model, 'vqgan'):
shared.log.error('VAE not found in model')
return []
- if hasattr(model, "_unpack_latents") and hasattr(model, "vae_scale_factor") and width is not None and height is not None: # FLUX
+ if hasattr(model, '_unpack_latents') and hasattr(model, 'transformer_spatial_patch_size') and frames is not None: # LTX
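+        # LTX packs video latents into a flat sequence: recover the latent frame count after temporal compression, unpack back to a spatio-temporal grid and de-normalize with the VAE statistics before decode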
+ latent_num_frames = (frames - 1) // model.vae_temporal_compression_ratio + 1
+ latents = model._unpack_latents(latents.unsqueeze(0), latent_num_frames, height // 32, width // 32, model.transformer_spatial_patch_size, model.transformer_temporal_patch_size) # pylint: disable=protected-access
+ latents = model._denormalize_latents(latents, model.vae.latents_mean, model.vae.latents_std, model.vae.config.scaling_factor) # pylint: disable=protected-access
+ if hasattr(model, '_unpack_latents') and hasattr(model, "vae_scale_factor") and width is not None and height is not None: # FLUX
latents = model._unpack_latents(latents, height, width, model.vae_scale_factor) # pylint: disable=protected-access
if len(latents.shape) == 3: # lost a batch dim in hires
latents = latents.unsqueeze(0)
@@ -176,12 +235,20 @@ def vae_decode(latents, model, output_type='np', full_quality=True, width=None,
decoded = latents.float().cpu().numpy()
elif full_quality and hasattr(model, "vae"):
decoded = full_vae_decode(latents=latents, model=model)
+ elif hasattr(model, "vqgan"):
+ decoded = full_vqgan_decode(latents=latents, model=model)
else:
decoded = taesd_vae_decode(latents=latents)
if torch.is_tensor(decoded):
- if hasattr(model, 'image_processor'):
+ if hasattr(model, 'video_processor'):
+ imgs = model.video_processor.postprocess_video(decoded, output_type='pil')
+ elif hasattr(model, 'image_processor'):
imgs = model.image_processor.postprocess(decoded, output_type=output_type)
+ elif hasattr(model, "vqgan"):
+ imgs = decoded.permute(0, 2, 3, 1).cpu().float().numpy()
+ if output_type == "pil":
+ imgs = model.numpy_to_pil(imgs)
else:
import diffusers
model.image_processor = diffusers.image_processor.VaeImageProcessor()
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 3a1288097..1eb7a77ee 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -147,7 +147,7 @@ def __default__(self, data, children, meta):
def get_schedule(prompt):
try:
tree = schedule_parser.parse(prompt)
- except lark.exceptions.LarkError:
+ except Exception:
return [[steps, prompt]]
return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
diff --git a/modules/prompt_parser_diffusers.py b/modules/prompt_parser_diffusers.py
index 234272907..4e31c747a 100644
--- a/modules/prompt_parser_diffusers.py
+++ b/modules/prompt_parser_diffusers.py
@@ -16,6 +16,7 @@
token_dict = None # used by helper get_tokens
token_type = None # used by helper get_tokens
cache = OrderedDict()
+last_attention = None
embedder = None
@@ -38,8 +39,6 @@ def prepare_model(pipe = None):
pipe = pipe.pipe
if not hasattr(pipe, "text_encoder"):
return None
- if shared.opts.diffusers_offload_mode == "balanced":
- pipe = sd_models.apply_balanced_offload(pipe)
elif hasattr(pipe, "maybe_free_model_hooks"):
pipe.maybe_free_model_hooks()
devices.torch_gc()
@@ -52,7 +51,7 @@ def __init__(self, prompts, negative_prompts, steps, clip_skip, p):
self.prompts = prompts
self.negative_prompts = negative_prompts
self.batchsize = len(self.prompts)
- self.attention = None
+ self.attention = last_attention
self.allsame = self.compare_prompts() # collapses batched prompts to single prompt if possible
self.steps = steps
self.clip_skip = clip_skip
@@ -64,6 +63,8 @@ def __init__(self, prompts, negative_prompts, steps, clip_skip, p):
self.positive_schedule = None
self.negative_schedule = None
self.scheduled_prompt = False
+ if hasattr(p, 'dummy'):
+ return
earlyout = self.checkcache(p)
if earlyout:
return
@@ -93,7 +94,7 @@ def flatten(xss):
return [x for xs in xss for x in xs]
# unpack EN data in case of TE LoRA
- en_data = p.extra_network_data
+ en_data = p.network_data
en_data = [idx.items for item in en_data.values() for idx in item]
effective_batch = 1 if self.allsame else self.batchsize
key = str([self.prompts, self.negative_prompts, effective_batch, self.clip_skip, self.steps, en_data])
@@ -113,6 +114,7 @@ def flatten(xss):
debug(f"Prompt cache: add={key}")
while len(cache) > int(shared.opts.sd_textencoder_cache_size):
cache.popitem(last=False)
+ return True
if item:
self.__dict__.update(cache[key])
cache.move_to_end(key)
@@ -161,8 +163,10 @@ def extend_embeds(self, batchidx, idx): # Extends scheduled prompt via index
self.negative_pooleds[batchidx].append(self.negative_pooleds[batchidx][idx])
def encode(self, pipe, positive_prompt, negative_prompt, batchidx):
+ global last_attention # pylint: disable=global-statement
self.attention = shared.opts.prompt_attention
- if self.attention == "xhinker" or 'Flux' in pipe.__class__.__name__:
+ last_attention = self.attention
+ if self.attention == "xhinker":
prompt_embed, positive_pooled, negative_embed, negative_pooled = get_xhinker_text_embeddings(pipe, positive_prompt, negative_prompt, self.clip_skip)
else:
prompt_embed, positive_pooled, negative_embed, negative_pooled = get_weighted_text_embeddings(pipe, positive_prompt, negative_prompt, self.clip_skip)
@@ -178,7 +182,6 @@ def encode(self, pipe, positive_prompt, negative_prompt, batchidx):
if debug_enabled:
get_tokens(pipe, 'positive', positive_prompt)
get_tokens(pipe, 'negative', negative_prompt)
- pipe = prepare_model()
def __call__(self, key, step=0):
batch = getattr(self, key)
@@ -194,8 +197,6 @@ def __call__(self, key, step=0):
def compel_hijack(self, token_ids: torch.Tensor, attention_mask: typing.Optional[torch.Tensor] = None) -> torch.Tensor:
- if not devices.same_device(self.text_encoder.device, devices.device):
- sd_models.move_model(self.text_encoder, devices.device)
needs_hidden_states = self.returned_embeddings_type != 1
text_encoder_output = self.text_encoder(token_ids, attention_mask, output_hidden_states=needs_hidden_states, return_dict=True)
@@ -372,25 +373,31 @@ def prepare_embedding_providers(pipe, clip_skip) -> list[EmbeddingsProvider]:
embedding_type = -(clip_skip + 1)
else:
embedding_type = clip_skip
+ embedding_args = {
+ 'truncate': False,
+ 'returned_embeddings_type': embedding_type,
+ 'device': device,
+ 'dtype_for_device_getter': lambda device: devices.dtype,
+ }
if getattr(pipe, "prior_pipe", None) is not None and getattr(pipe.prior_pipe, "tokenizer", None) is not None and getattr(pipe.prior_pipe, "text_encoder", None) is not None:
- provider = EmbeddingsProvider(padding_attention_mask_value=0, tokenizer=pipe.prior_pipe.tokenizer, text_encoder=pipe.prior_pipe.text_encoder, truncate=False, returned_embeddings_type=embedding_type, device=device)
+ provider = EmbeddingsProvider(padding_attention_mask_value=0, tokenizer=pipe.prior_pipe.tokenizer, text_encoder=pipe.prior_pipe.text_encoder, **embedding_args)
embeddings_providers.append(provider)
- no_mask_provider = EmbeddingsProvider(padding_attention_mask_value=1 if "sote" in pipe.sd_checkpoint_info.name.lower() else 0, tokenizer=pipe.prior_pipe.tokenizer, text_encoder=pipe.prior_pipe.text_encoder, truncate=False, returned_embeddings_type=embedding_type, device=device)
+ no_mask_provider = EmbeddingsProvider(padding_attention_mask_value=1 if "sote" in pipe.sd_checkpoint_info.name.lower() else 0, tokenizer=pipe.prior_pipe.tokenizer, text_encoder=pipe.prior_pipe.text_encoder, **embedding_args)
embeddings_providers.append(no_mask_provider)
elif getattr(pipe, "tokenizer", None) is not None and getattr(pipe, "text_encoder", None) is not None:
- if not devices.same_device(pipe.text_encoder.device, devices.device):
- sd_models.move_model(pipe.text_encoder, devices.device)
- provider = EmbeddingsProvider(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder, truncate=False, returned_embeddings_type=embedding_type, device=device)
+ if pipe.text_encoder.__class__.__name__.startswith('CLIP'):
+ sd_models.move_model(pipe.text_encoder, devices.device, force=True)
+ provider = EmbeddingsProvider(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder, **embedding_args)
embeddings_providers.append(provider)
if getattr(pipe, "tokenizer_2", None) is not None and getattr(pipe, "text_encoder_2", None) is not None:
- if not devices.same_device(pipe.text_encoder_2.device, devices.device):
- sd_models.move_model(pipe.text_encoder_2, devices.device)
- provider = EmbeddingsProvider(tokenizer=pipe.tokenizer_2, text_encoder=pipe.text_encoder_2, truncate=False, returned_embeddings_type=embedding_type, device=device)
+ if pipe.text_encoder_2.__class__.__name__.startswith('CLIP'):
+ sd_models.move_model(pipe.text_encoder_2, devices.device, force=True)
+ provider = EmbeddingsProvider(tokenizer=pipe.tokenizer_2, text_encoder=pipe.text_encoder_2, **embedding_args)
embeddings_providers.append(provider)
if getattr(pipe, "tokenizer_3", None) is not None and getattr(pipe, "text_encoder_3", None) is not None:
- if not devices.same_device(pipe.text_encoder_3.device, devices.device):
- sd_models.move_model(pipe.text_encoder_3, devices.device)
- provider = EmbeddingsProvider(tokenizer=pipe.tokenizer_3, text_encoder=pipe.text_encoder_3, truncate=False, returned_embeddings_type=embedding_type, device=device)
+ if pipe.text_encoder_3.__class__.__name__.startswith('CLIP'):
+ sd_models.move_model(pipe.text_encoder_3, devices.device, force=True)
+ provider = EmbeddingsProvider(tokenizer=pipe.tokenizer_3, text_encoder=pipe.text_encoder_3, **embedding_args)
embeddings_providers.append(provider)
return embeddings_providers
@@ -583,15 +590,15 @@ def get_xhinker_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", cl
te1_device, te2_device, te3_device = None, None, None
if hasattr(pipe, "text_encoder") and pipe.text_encoder.device != devices.device:
te1_device = pipe.text_encoder.device
- sd_models.move_model(pipe.text_encoder, devices.device)
+ sd_models.move_model(pipe.text_encoder, devices.device, force=True)
if hasattr(pipe, "text_encoder_2") and pipe.text_encoder_2.device != devices.device:
te2_device = pipe.text_encoder_2.device
- sd_models.move_model(pipe.text_encoder_2, devices.device)
+ sd_models.move_model(pipe.text_encoder_2, devices.device, force=True)
if hasattr(pipe, "text_encoder_3") and pipe.text_encoder_3.device != devices.device:
te3_device = pipe.text_encoder_3.device
- sd_models.move_model(pipe.text_encoder_3, devices.device)
+ sd_models.move_model(pipe.text_encoder_3, devices.device, force=True)
- if is_sd3:
+ if 'StableDiffusion3' in pipe.__class__.__name__:
prompt_embed, negative_embed, positive_pooled, negative_pooled = get_weighted_text_embeddings_sd3(pipe=pipe, prompt=prompt, neg_prompt=neg_prompt, use_t5_encoder=bool(pipe.text_encoder_3))
elif 'Flux' in pipe.__class__.__name__:
prompt_embed, positive_pooled = get_weighted_text_embeddings_flux1(pipe=pipe, prompt=prompt, prompt2=prompt_2, device=devices.device)
@@ -601,10 +608,10 @@ def get_xhinker_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", cl
prompt_embed, negative_embed = get_weighted_text_embeddings_sd15(pipe=pipe, prompt=prompt, neg_prompt=neg_prompt, clip_skip=clip_skip)
if te1_device is not None:
- sd_models.move_model(pipe.text_encoder, te1_device)
+ sd_models.move_model(pipe.text_encoder, te1_device, force=True)
if te2_device is not None:
- sd_models.move_model(pipe.text_encoder_2, te1_device)
+ sd_models.move_model(pipe.text_encoder_2, te1_device, force=True)
if te3_device is not None:
- sd_models.move_model(pipe.text_encoder_3, te1_device)
+ sd_models.move_model(pipe.text_encoder_3, te1_device, force=True)
return prompt_embed, positive_pooled, negative_embed, negative_pooled
diff --git a/modules/pulid/eva_clip/hf_model.py b/modules/pulid/eva_clip/hf_model.py
index c4b9fd85b..d148bbff2 100644
--- a/modules/pulid/eva_clip/hf_model.py
+++ b/modules/pulid/eva_clip/hf_model.py
@@ -31,7 +31,6 @@ class PretrainedConfig:
def _camel2snake(s):
     return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
diff --git a/modules/schedulers/scheduler_bdia.py b/modules/schedulers/scheduler_bdia.py
new file mode 100644
--- /dev/null
+++ b/modules/schedulers/scheduler_bdia.py
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+from diffusers.utils import BaseOutput
+from diffusers.utils.torch_utils import randn_tensor
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
+class DDIMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.Tensor
+ pred_original_sample: Optional[torch.Tensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+def rescale_zero_terminal_snr(betas):
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+ Args:
+ betas (`torch.Tensor`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `torch.Tensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+class BDIA_DDIMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
+ non-Markovian guidance.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, *optional*):
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+ clip_sample (`bool`, defaults to `True`):
+ Clip the predicted sample for numerical stability.
+ clip_sample_range (`float`, defaults to 1.0):
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, defaults to `True`):
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the alpha value at step 0.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps, as required by some model families.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, defaults to `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
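+        gamma (`float`, defaults to 1.0):
+            Weight of the BDIA correction term; `gamma=0` reduces every update to a plain DDIM step.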
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True, #was True
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ clip_sample_range: float = 1.0,
+ sample_max_value: float = 1.0,
+ timestep_spacing: str = "leading", #leading
+ rescale_betas_zero_snr: bool = False,
+ gamma: float = 1.0,
+
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ # Rescale for zero SNR
+ if rescale_betas_zero_snr:
+ self.betas = rescale_zero_terminal_snr(self.betas)
+
+ self.alphas = 1.0 - self.betas #may have to add something for last step
+
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+ self.next_sample = []
+ self.BDIA = False
+
+
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+ `torch.Tensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, *remaining_dims = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
+ sample = sample.to(dtype)
+
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ """
+
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = (
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+ .round()[::-1]
+ .copy()
+ .astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
+ timesteps -= 1
+ else:
+ raise ValueError(
+                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ def step(
+ self,
+ model_output: torch.Tensor,
+ timestep: int,
+ sample: torch.Tensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ variance_noise: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ debug: bool = False,
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE.
+
+ Args:
+ model_output (torch.Tensor): Direct output from learned diffusion model
+ timestep (int): Current discrete timestep in the diffusion chain
+ sample (torch.Tensor): Current instance of sample created by diffusion process
+ eta (float): Weight of noise for added noise in diffusion step
+ use_clipped_model_output (bool): Whether to use clipped model output
+ generator (torch.Generator, optional): Random number generator
+ variance_noise (torch.Tensor, optional): Pre-generated noise for variance
+ return_dict (bool): Whether to return as DDIMSchedulerOutput or tuple
+ debug (bool): Whether to print debug information
+ """
+ if self.num_inference_steps is None:
+ raise ValueError("Number of inference steps is 'None', run 'set_timesteps' first")
+
+ # Calculate timesteps
+ step_size = self.config.num_train_timesteps // self.num_inference_steps
+ prev_timestep = timestep - step_size
+ next_timestep = timestep + step_size
+
+ if debug:
+ print("\n=== Timestep Information ===")
+ print(f"Current timestep: {timestep}")
+ print(f"Previous timestep: {prev_timestep}")
+ print(f"Next timestep: {next_timestep}")
+ print(f"Step size: {step_size}")
+
+ # Pre-compute alpha and variance values
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ variance = self._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance ** 0.5
+
+ # Compute required values
+ alpha_i = alpha_prod_t ** 0.5
+ alpha_i_minus_1 = alpha_prod_t_prev ** 0.5
+ sigma_i = (1 - alpha_prod_t) ** 0.5
+ sigma_i_minus_1 = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5
+
+ if debug:
+ print("\n=== Alpha Values ===")
+ print(f"alpha_i: {alpha_i}")
+ print(f"alpha_i_minus_1: {alpha_i_minus_1}")
+ print(f"sigma_i: {sigma_i}")
+ print(f"sigma_i_minus_1: {sigma_i_minus_1}")
+
+ # Predict original sample based on prediction type
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - sigma_i * model_output) / alpha_i
+ pred_epsilon = model_output
+ if debug:
+ print("\nPrediction type: epsilon")
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ pred_epsilon = (sample - alpha_i * pred_original_sample) / sigma_i
+ if debug:
+ print("\nPrediction type: sample")
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = alpha_i * sample - sigma_i * model_output
+ pred_epsilon = alpha_i * model_output + sigma_i * sample
+ if debug:
+ print("\nPrediction type: v_prediction")
+ else:
+ raise ValueError(
+ f"prediction_type {self.config.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`"
+ )
+
+ # Apply thresholding or clipping if configured
+ if self.config.thresholding:
+ if debug:
+ print("\nApplying thresholding")
+ pred_original_sample = self._threshold_sample(pred_original_sample)
+ elif self.config.clip_sample:
+ if debug:
+ print("\nApplying clipping")
+ pred_original_sample = pred_original_sample.clamp(
+ -self.config.clip_sample_range, self.config.clip_sample_range
+ )
+
+ # Recompute pred_epsilon if using clipped model output
+ if use_clipped_model_output:
+ if debug:
+ print("\nUsing clipped model output")
+ pred_epsilon = (sample - alpha_i * pred_original_sample) / sigma_i
+
+ # Compute DDIM step
+ ddim_step = alpha_i_minus_1 * pred_original_sample + sigma_i_minus_1 * pred_epsilon
+
+ # Handle initial DDIM step or BDIA steps
+ if len(self.next_sample) == 0:
+ if debug:
+ print("\nFirst iteration (DDIM)")
+ self.update_next_sample_BDIA(sample)
+ self.update_next_sample_BDIA(ddim_step)
+ else:
+ if debug:
+ print("\nBDIA step")
+ # BDIA implementation
+ alpha_prod_t_next = self.alphas_cumprod[next_timestep]
+ alpha_i_plus_1 = alpha_prod_t_next ** 0.5
+ sigma_i_plus_1 = (1 - alpha_prod_t_next) ** 0.5
+
+ if debug:
+ print(f"alpha_i_plus_1: {alpha_i_plus_1}")
+ print(f"sigma_i_plus_1: {sigma_i_plus_1}")
+
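+            # 'a' is the DDIM estimate propagated forward to the next (noisier) timestep; BDIA blends it with the
+            # sample stored two updates back: x_prev = gamma * stored + ddim_step - gamma * a (gamma=0 gives plain DDIM)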
+ a = alpha_i_plus_1 * pred_original_sample + sigma_i_plus_1 * pred_epsilon
+ bdia_step = (
+ self.config.gamma * self.next_sample[-2] +
+ ddim_step -
+ (self.config.gamma * a)
+ )
+ self.update_next_sample_BDIA(bdia_step)
+
+ prev_sample = self.next_sample[-1]
+
+ # Apply variance noise if eta > 0
+ if eta > 0:
+ if debug:
+ print(f"\nApplying variance noise with eta: {eta}")
+
+ if variance_noise is not None and generator is not None:
+ raise ValueError(
+ "Cannot pass both generator and variance_noise. Use either `generator` or `variance_noise`."
+ )
+
+ if variance_noise is None:
+ variance_noise = randn_tensor(
+ model_output.shape,
+ generator=generator,
+ device=model_output.device,
+ dtype=model_output.dtype
+ )
+ prev_sample = prev_sample + std_dev_t * variance_noise
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
+ # for the subsequent add_noise calls
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
+ alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
+ alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def update_next_sample_BDIA(self, new_value):
+ self.next_sample.append(new_value.clone())
+
+
+ def __len__(self):
+ return self.config.num_train_timesteps
\ No newline at end of file
diff --git a/modules/schedulers/scheduler_dpm_flowmatch.py b/modules/schedulers/scheduler_dpm_flowmatch.py
index 1afe54498..69452aca9 100644
--- a/modules/schedulers/scheduler_dpm_flowmatch.py
+++ b/modules/schedulers/scheduler_dpm_flowmatch.py
@@ -22,7 +22,8 @@ def __init__(self, x, t0, t1, seed=None, **kwargs):
t0, t1, self.sign = self.sort(t0, t1)
w0 = kwargs.get("w0", torch.zeros_like(x))
if seed is None:
- seed = torch.randint(0, 2**63 - 1, []).item()
+ seed = [torch.randint(0, 2**63 - 1, []).item()]
+ seed = [s.initial_seed() if isinstance(s, torch.Generator) else s for s in seed]
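+        # callers may pass torch.Generator objects instead of integers; normalize each entry to a plain integer seed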
self.batched = True
try:
assert len(seed) == x.shape[0]
diff --git a/modules/schedulers/scheduler_tcd.py b/modules/schedulers/scheduler_tcd.py
index bff12ab5a..9b2d4d35a 100644
--- a/modules/schedulers/scheduler_tcd.py
+++ b/modules/schedulers/scheduler_tcd.py
@@ -447,7 +447,6 @@ def set_timesteps(
init_timestep = min(int(self.num_inference_steps * strength), self.num_inference_steps)
t_start = max(self.num_inference_steps - init_timestep, 0)
timesteps = timesteps[t_start * self.order :]
- # TODO: also reset self.num_inference_steps?
else:
# 2.2 Create the "standard" TCD inference timestep schedule.
if num_inference_steps > self.config.num_train_timesteps:
diff --git a/modules/schedulers/scheduler_ufogen.py b/modules/schedulers/scheduler_ufogen.py
new file mode 100644
index 000000000..f4d8aee97
--- /dev/null
+++ b/modules/schedulers/scheduler_ufogen.py
@@ -0,0 +1,523 @@
+# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim and https://github.com/xuyanwu/SIDDMs-UFOGen
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->UFOGen
+class UFOGenSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+ """
+ Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+ Args:
+ betas (`torch.FloatTensor`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
+class UFOGenScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `UFOGenScheduler` implements multistep and onestep sampling for a UFOGen model, introduced in
+ [UFOGen: You Forward Once Large Scale Text-to-Image Generation via Diffusion GANs](https://arxiv.org/abs/2311.09257)
+ by Yanwu Xu, Yang Zhao, Zhisheng Xiao, and Tingbo Hou. UFOGen is a variant of the denoising diffusion GAN (DDGAN)
+ model designed for one-step sampling.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ clip_sample (`bool`, defaults to `True`):
+ Clip the predicted sample for numerical stability.
+ clip_sample_range (`float`, defaults to 1.0):
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, defaults to `True`):
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the alpha value at step 0.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, defaults to `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
+ Diffusion.
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ denoising_step_size (`int`, defaults to 250):
+ The denoising step size parameter from the UFOGen paper. The number of steps used for training is roughly
+ `math.ceil(num_train_timesteps / denoising_step_size)`.
+ """
+
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ clip_sample_range: float = 1.0,
+ sample_max_value: float = 1.0,
+ timestep_spacing: str = "leading",
+ steps_offset: int = 0,
+ rescale_betas_zero_snr: bool = False,
+ denoising_step_size: int = 250,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ elif beta_schedule == "sigmoid":
+ # GeoDiff sigmoid schedule
+ betas = torch.linspace(-6, 6, num_train_timesteps)
+ self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ # Rescale for zero SNR
+ if rescale_betas_zero_snr:
+ self.betas = rescale_zero_terminal_snr(self.betas)
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.custom_timesteps = False
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+ `torch.FloatTensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ def set_timesteps(
+ self,
+ num_inference_steps: Optional[int] = None,
+ device: Union[str, torch.device] = None,
+ timesteps: Optional[List[int]] = None,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed,
+ `num_inference_steps` must be `None`.
+
+ """
+ if num_inference_steps is not None and timesteps is not None:
+ raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
+
+ if timesteps is not None:
+ for i in range(1, len(timesteps)):
+ if timesteps[i] >= timesteps[i - 1]:
+ raise ValueError("`custom_timesteps` must be in descending order.")
+
+ if timesteps[0] >= self.config.num_train_timesteps:
+ raise ValueError(
+ f"`timesteps` must start before `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps}."
+ )
+
+ timesteps = np.array(timesteps, dtype=np.int64)
+ self.custom_timesteps = True
+ else:
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+ self.custom_timesteps = False
+
+ if num_inference_steps == 1:
+ # Set the timestep schedule to num_train_timesteps - 1 rather than 0
+ # (that is, the one-step timestep schedule is always trailing rather than leading or linspace)
+ timesteps = np.array([self.config.num_train_timesteps - 1], dtype=np.int64)
+ else:
+ if self.config.timestep_spacing == "linspace":
+ timesteps = (
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
+ .round()[::-1]
+ .copy()
+ .astype(np.int64)
+ )
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, *remaining_dims = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
+ sample = sample.to(dtype)
+
+ return sample
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[UFOGenSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~schedulers.scheduling_ufogen.UFOGenSchedulerOutput`] or `tuple`.
+
+ Returns:
+ [`~schedulers.scheduling_ufogen.UFOGenSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_ufogen.UFOGenSchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+
+ """
+ # 0. Resolve timesteps
+ t = timestep
+ prev_t = self.previous_timestep(t)
+
+ # 1. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ # beta_prod_t_prev = 1 - alpha_prod_t_prev
+ # current_alpha_t = alpha_prod_t / alpha_prod_t_prev
+ # current_beta_t = 1 - current_alpha_t
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
+ " `v_prediction` for UFOGenScheduler."
+ )
+
+ # 3. Clip or threshold "predicted x_0"
+ if self.config.thresholding:
+ pred_original_sample = self._threshold_sample(pred_original_sample)
+ elif self.config.clip_sample:
+ pred_original_sample = pred_original_sample.clamp(
+ -self.config.clip_sample_range, self.config.clip_sample_range
+ )
+
+ # 4. Single-step or multi-step sampling
+ # Noise is not used on the final timestep of the timestep schedule.
+ # This also means that noise is not used for one-step sampling.
+ if t != self.timesteps[-1]:
+ device = model_output.device
+ noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
+ sqrt_alpha_prod_t_prev = alpha_prod_t_prev**0.5
+ sqrt_one_minus_alpha_prod_t_prev = (1 - alpha_prod_t_prev) ** 0.5
+ pred_prev_sample = sqrt_alpha_prod_t_prev * pred_original_sample + sqrt_one_minus_alpha_prod_t_prev * noise
+ else:
+ # Simply return the pred_original_sample. If `prediction_type == "sample"`, this is equivalent to returning
+ # the output of the GAN generator U-Net on the initial noisy latents x_T ~ N(0, I).
+ pred_prev_sample = pred_original_sample
+
+ if not return_dict:
+ return (pred_prev_sample,)
+
+ return UFOGenSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+ def get_velocity(
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
+ def previous_timestep(self, timestep):
+ if self.custom_timesteps:
+ index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
+ if index == self.timesteps.shape[0] - 1:
+ prev_t = torch.tensor(-1)
+ else:
+ prev_t = self.timesteps[index + 1]
+ else:
+ num_inference_steps = (
+ self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
+ )
+ prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+
+ return prev_t
\ No newline at end of file
diff --git a/modules/schedulers/scheduler_vdm.py b/modules/schedulers/scheduler_vdm.py
index 543b29ff3..492c30a0c 100644
--- a/modules/schedulers/scheduler_vdm.py
+++ b/modules/schedulers/scheduler_vdm.py
@@ -147,7 +147,7 @@ def __init__(
self.timesteps = torch.from_numpy(self.get_timesteps(len(self)))
if num_train_timesteps:
alphas_cumprod = self.alphas_cumprod(torch.flip(self.timesteps, dims=(0,)))
- alphas = alphas_cumprod[1:] / alphas_cumprod[:-1] # TODO: Might not be exact
+ alphas = alphas_cumprod[1:] / alphas_cumprod[:-1]
self.alphas = torch.cat([alphas_cumprod[:1], alphas])
self.betas = 1 - self.alphas
diff --git a/modules/scripts.py b/modules/scripts.py
index cf2cf25b9..9c410a3b6 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -288,6 +288,7 @@ def register_scripts_from_module(module, scriptfile):
current_basedir = paths.script_path
t.record(os.path.basename(scriptfile.basedir) if scriptfile.basedir != paths.script_path else scriptfile.filename)
sys.path = syspath
+
global scripts_txt2img, scripts_img2img, scripts_control, scripts_postproc # pylint: disable=global-statement
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
@@ -329,6 +330,7 @@ def __init__(self):
self.scripts = []
self.selectable_scripts = []
self.alwayson_scripts = []
+ self.auto_processing_scripts = []
self.titles = []
self.infotext_fields = []
self.paste_field_names = []
@@ -336,6 +338,31 @@ def __init__(self):
self.is_img2img = False
self.inputs = [None]
+ def add_script(self, script_class, path, is_img2img, is_control):
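+ # instantiate the script class, tag which tab it was created for, and register it as always-on or selectable based on its show() result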
+ try:
+ script = script_class()
+ script.filename = path
+ script.is_txt2img = not is_img2img
+ script.is_img2img = is_img2img
+ if is_control: # this is messy but show() is a legacy function that is not aware of the control tab
+ v1 = script.show(script.is_txt2img)
+ v2 = script.show(script.is_img2img)
+ if v1 == AlwaysVisible or v2 == AlwaysVisible:
+ visibility = AlwaysVisible
+ else:
+ visibility = v1 or v2
+ else:
+ visibility = script.show(script.is_img2img)
+ if visibility == AlwaysVisible:
+ self.scripts.append(script)
+ self.alwayson_scripts.append(script)
+ script.alwayson = True
+ elif visibility:
+ self.scripts.append(script)
+ self.selectable_scripts.append(script)
+ except Exception as e:
+ errors.log.error(f'Script initialize: {path} {e}')
+
def initialize_scripts(self, is_img2img=False, is_control=False):
from modules import scripts_auto_postprocessing
@@ -350,34 +377,14 @@ def initialize_scripts(self, is_img2img=False, is_control=False):
self.scripts.clear()
self.alwayson_scripts.clear()
self.selectable_scripts.clear()
- auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
+ self.auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
- all_scripts = auto_processing_scripts + scripts_data
- sorted_scripts = sorted(all_scripts, key=lambda x: x.script_class().title().lower())
+ sorted_scripts = sorted(scripts_data, key=lambda x: x.script_class().title().lower())
for script_class, path, _basedir, _script_module in sorted_scripts:
- try:
- script = script_class()
- script.filename = path
- script.is_txt2img = not is_img2img
- script.is_img2img = is_img2img
- if is_control: # this is messy but show is a legacy function that is not aware of control tab
- v1 = script.show(script.is_txt2img)
- v2 = script.show(script.is_img2img)
- if v1 == AlwaysVisible or v2 == AlwaysVisible:
- visibility = AlwaysVisible
- else:
- visibility = v1 or v2
- else:
- visibility = script.show(script.is_img2img)
- if visibility == AlwaysVisible:
- self.scripts.append(script)
- self.alwayson_scripts.append(script)
- script.alwayson = True
- elif visibility:
- self.scripts.append(script)
- self.selectable_scripts.append(script)
- except Exception as e:
- errors.log.error(f'Script initialize: {path} {e}')
+ self.add_script(script_class, path, is_img2img, is_control)
+ sorted_scripts = sorted(self.auto_processing_scripts, key=lambda x: x.script_class().title().lower())
+ for script_class, path, _basedir, _script_module in sorted_scripts:
+ self.add_script(script_class, path, is_img2img, is_control)
def prepare_ui(self):
self.inputs = [None]
diff --git a/modules/sd_checkpoint.py b/modules/sd_checkpoint.py
index afc5842e4..6ab396329 100644
--- a/modules/sd_checkpoint.py
+++ b/modules/sd_checkpoint.py
@@ -123,13 +123,17 @@ def list_models():
checkpoint_aliases.clear()
ext_filter = [".safetensors"] if shared.opts.sd_disable_ckpt or shared.native else [".ckpt", ".safetensors"]
model_list = list(modelloader.load_models(model_path=model_path, model_url=None, command_path=shared.opts.ckpt_dir, ext_filter=ext_filter, download_name=None, ext_blacklist=[".vae.ckpt", ".vae.safetensors"]))
+ safetensors_list = []
for filename in sorted(model_list, key=str.lower):
checkpoint_info = CheckpointInfo(filename)
+ safetensors_list.append(checkpoint_info)
if checkpoint_info.name is not None:
checkpoint_info.register()
+ diffusers_list = []
if shared.native:
for repo in modelloader.load_diffusers_models(clear=True):
checkpoint_info = CheckpointInfo(repo['name'], sha=repo['hash'])
+ diffusers_list.append(checkpoint_info)
if checkpoint_info.name is not None:
checkpoint_info.register()
if shared.cmd_opts.ckpt is not None:
@@ -143,7 +147,7 @@ def list_models():
shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
elif shared.cmd_opts.ckpt != shared.default_sd_model_file and shared.cmd_opts.ckpt is not None:
shared.log.warning(f'Load model: path="{shared.cmd_opts.ckpt}" not found')
- shared.log.info(f'Available Models: path="{shared.opts.ckpt_dir}" items={len(checkpoints_list)} time={time.time()-t0:.2f}')
+ shared.log.info(f'Available Models: items={len(checkpoints_list)} safetensors="{shared.opts.ckpt_dir}":{len(safetensors_list)} diffusers="{shared.opts.diffusers_dir}":{len(diffusers_list)} time={time.time()-t0:.2f}')
checkpoints_list = dict(sorted(checkpoints_list.items(), key=lambda cp: cp[1].filename))
def update_model_hashes():
@@ -168,7 +172,10 @@ def update_model_hashes():
def get_closet_checkpoint_match(s: str):
if s.startswith('https://huggingface.co/'):
- s = s.replace('https://huggingface.co/', '')
+ model_name = s.replace('https://huggingface.co/', '')
+ checkpoint_info = CheckpointInfo(model_name) # create a virtual model info
+ checkpoint_info.type = 'huggingface'
+ return checkpoint_info
if s.startswith('huggingface/'):
model_name = s.replace('huggingface/', '')
checkpoint_info = CheckpointInfo(model_name) # create a virutal model info
@@ -185,6 +192,11 @@ def get_closet_checkpoint_match(s: str):
if found and len(found) == 1:
return found[0]
+ # absolute path
+ if s.endswith('.safetensors') and os.path.isfile(s):
+ checkpoint_info = CheckpointInfo(s)
+ return checkpoint_info
+
# reference search
"""
found = sorted([info for info in shared.reference_models.values() if os.path.basename(info['path']).lower().startswith(s.lower())], key=lambda x: len(x['path']))
@@ -198,8 +210,9 @@ def get_closet_checkpoint_match(s: str):
if shared.opts.sd_checkpoint_autodownload and s.count('/') == 1:
modelloader.hf_login()
found = modelloader.find_diffuser(s, full=True)
+ found = [f for f in found if f == s]
shared.log.info(f'HF search: model="{s}" results={found}')
- if found is not None and len(found) == 1 and found[0] == s:
+ if found is not None and len(found) == 1:
checkpoint_info = CheckpointInfo(s)
checkpoint_info.type = 'huggingface'
return checkpoint_info
@@ -240,6 +253,8 @@ def select_checkpoint(op='model'):
model_checkpoint = shared.opts.data.get('sd_model_refiner', None)
else:
model_checkpoint = shared.opts.sd_model_checkpoint
+ if model_checkpoint is None or len(model_checkpoint) < 3:
+ return None
if model_checkpoint is None or model_checkpoint == 'None':
return None
checkpoint_info = get_closet_checkpoint_match(model_checkpoint)
@@ -266,6 +281,12 @@ def select_checkpoint(op='model'):
return checkpoint_info
+def init_metadata():
+ global sd_metadata # pylint: disable=global-statement
+ if sd_metadata is None:
+ sd_metadata = shared.readfile(sd_metadata_file, lock=True) if os.path.isfile(sd_metadata_file) else {}
+
+
def read_metadata_from_safetensors(filename):
global sd_metadata # pylint: disable=global-statement
if sd_metadata is None:
diff --git a/modules/sd_detect.py b/modules/sd_detect.py
index 062bb32e1..fe3d325c9 100644
--- a/modules/sd_detect.py
+++ b/modules/sd_detect.py
@@ -71,6 +71,8 @@ def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False):
guess = 'Stable Cascade'
if 'pixart-sigma' in f.lower():
guess = 'PixArt-Sigma'
+ if 'sana' in f.lower():
+ guess = 'Sana'
if 'lumina-next' in f.lower():
guess = 'Lumina-Next'
if 'kolors' in f.lower():
@@ -94,6 +96,17 @@ def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False):
guess = 'FLUX'
if size > 11000 and size < 16000:
warn(f'Model detected as FLUX UNET model, but attempting to load a base model: {op}={f} size={size} MB')
+ # guess for diffusers
+ index = os.path.join(f, 'model_index.json')
+ if os.path.exists(index) and os.path.isfile(index):
+ index = shared.readfile(index, silent=True)
+ cls = index.get('_class_name', None)
+ if cls is not None:
+ pipeline = getattr(diffusers, cls)
+ if 'Flux' in pipeline.__name__:
+ guess = 'FLUX'
+ if 'StableDiffusion3' in pipeline.__name__:
+ guess = 'Stable Diffusion 3'
# switch for specific variant
if guess == 'Stable Diffusion' and 'inpaint' in f.lower():
guess = 'Stable Diffusion Inpaint'
@@ -105,7 +118,7 @@ def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False):
guess = 'Stable Diffusion XL Instruct'
# get actual pipeline
pipeline = shared_items.get_pipelines().get(guess, None) if pipeline is None else pipeline
- if not quiet:
+ if debug_load is not None:
shared.log.info(f'Autodetect {op}: detect="{guess}" class={getattr(pipeline, "__name__", None)} file="{f}" size={size}MB')
t0 = time.time()
keys = model_tools.get_safetensor_keys(f)
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index e9ac1be92..688af0f34 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -53,7 +53,7 @@ def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
# this file is always 404, prevent making request
- if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
+ if (url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14') and args[0] == 'added_tokens.json':
return None
try:
diff --git a/modules/sd_hijack_accelerate.py b/modules/sd_hijack_accelerate.py
index 90eac5c4e..f8cf8983f 100644
--- a/modules/sd_hijack_accelerate.py
+++ b/modules/sd_hijack_accelerate.py
@@ -35,10 +35,10 @@ def hijack_set_module_tensor(
with devices.inference_context():
# note: majority of time is spent on .to(old_value.dtype)
if tensor_name in module._buffers: # pylint: disable=protected-access
- module._buffers[tensor_name] = value.to(device, old_value.dtype, non_blocking=True) # pylint: disable=protected-access
+ module._buffers[tensor_name] = value.to(device, old_value.dtype) # pylint: disable=protected-access
elif value is not None or not devices.same_device(torch.device(device), module._parameters[tensor_name].device): # pylint: disable=protected-access
param_cls = type(module._parameters[tensor_name]) # pylint: disable=protected-access
- module._parameters[tensor_name] = param_cls(value, requires_grad=old_value.requires_grad).to(device, old_value.dtype, non_blocking=True) # pylint: disable=protected-access
+ module._parameters[tensor_name] = param_cls(value, requires_grad=old_value.requires_grad).to(device, old_value.dtype) # pylint: disable=protected-access
t1 = time.time()
tensor_to_timer += (t1 - t0)
@@ -63,10 +63,10 @@ def hijack_set_module_tensor_simple(
old_value = getattr(module, tensor_name)
with devices.inference_context():
if tensor_name in module._buffers: # pylint: disable=protected-access
- module._buffers[tensor_name] = value.to(device, non_blocking=True) # pylint: disable=protected-access
+ module._buffers[tensor_name] = value.to(device) # pylint: disable=protected-access
elif value is not None or not devices.same_device(torch.device(device), module._parameters[tensor_name].device): # pylint: disable=protected-access
param_cls = type(module._parameters[tensor_name]) # pylint: disable=protected-access
- module._parameters[tensor_name] = param_cls(value, requires_grad=old_value.requires_grad).to(device, non_blocking=True) # pylint: disable=protected-access
+ module._parameters[tensor_name] = param_cls(value, requires_grad=old_value.requires_grad).to(device) # pylint: disable=protected-access
t1 = time.time()
tensor_to_timer += (t1 - t0)
diff --git a/modules/sd_hijack_hypertile.py b/modules/sd_hijack_hypertile.py
index dbf977b8d..69c4163dc 100644
--- a/modules/sd_hijack_hypertile.py
+++ b/modules/sd_hijack_hypertile.py
@@ -112,7 +112,7 @@ def wrapper(*args, **kwargs):
out = forward(x, *args[1:], **kwargs)
return out
if x.ndim == 4: # VAE
- # TODO hypertile vae breaks for diffusers when using non-standard sizes
+ # TODO hypertile: vae breaks when using non-standard sizes
if nh * nw > 1:
x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw)
out = forward(x, *args[1:], **kwargs)
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index d8d356071..66040c5aa 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -70,7 +70,7 @@ def hijack_ddpm_edit():
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
- CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
+ CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: (kwargs.update({'act_layer': GELUHijack}) and False) or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 # pylint: disable=unnecessary-lambda-assignment
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs) # pylint: disable=unnecessary-lambda-assignment
diff --git a/modules/sd_models.py b/modules/sd_models.py
index cf1921a36..087e8ecfc 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -13,10 +13,11 @@
from rich import progress # pylint: disable=redefined-builtin
import torch
import safetensors.torch
+import accelerate
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from modules import paths, shared, shared_state, modelloader, devices, script_callbacks, sd_vae, sd_unet, errors, sd_models_config, sd_models_compile, sd_hijack_accelerate, sd_detect
-from modules.timer import Timer
+from modules.timer import Timer, process as process_timer
from modules.memstats import memory_stats
from modules.modeldata import model_data
from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, checkpoints_list, checkpoint_titles, get_closet_checkpoint_match, model_hash, update_model_hashes, setup_model, write_metadata, read_metadata_from_safetensors # pylint: disable=unused-import
@@ -33,6 +34,8 @@
debug_process = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None
diffusers_version = int(diffusers.__version__.split('.')[1])
checkpoint_tiles = checkpoint_titles # legacy compatibility
+should_offload = ['sc', 'sd3', 'f1', 'hunyuandit', 'auraflow', 'omnigen']
+offload_hook_instance = None
class NoWatermark:
@@ -279,7 +282,7 @@ def eval_model(model, op=None, sd_model=None): # pylint: disable=unused-argument
model.eval()
return model
sd_model = sd_models_compile.apply_compile_to_model(sd_model, eval_model, ["Model", "VAE", "Text Encoder"], op="eval")
- if len(shared.opts.torchao_quantization) > 0:
+ if len(shared.opts.torchao_quantization) > 0 and shared.opts.torchao_quantization_mode == 'post':
sd_model = sd_models_compile.torchao_quantization(sd_model)
if shared.opts.opt_channelslast and hasattr(sd_model, 'unet'):
@@ -310,6 +313,7 @@ def set_accelerate(sd_model):
def set_diffuser_offload(sd_model, op: str = 'model'):
+ t0 = time.time()
if not shared.native:
shared.log.warning('Attempting to use offload with backend=original')
return
@@ -318,13 +322,17 @@ def set_diffuser_offload(sd_model, op: str = 'model'):
return
if not (hasattr(sd_model, "has_accelerate") and sd_model.has_accelerate):
sd_model.has_accelerate = False
- if hasattr(sd_model, 'maybe_free_model_hooks') and shared.opts.diffusers_offload_mode == "none":
- shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode}')
- sd_model.maybe_free_model_hooks()
- sd_model.has_accelerate = False
- if hasattr(sd_model, "enable_model_cpu_offload") and shared.opts.diffusers_offload_mode == "model":
+ if shared.opts.diffusers_offload_mode == "none":
+ if shared.sd_model_type in should_offload:
+ shared.log.warning(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} type={shared.sd_model.__class__.__name__} large model')
+ else:
+ shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
+ if hasattr(sd_model, 'maybe_free_model_hooks'):
+ sd_model.maybe_free_model_hooks()
+ sd_model.has_accelerate = False
+ if shared.opts.diffusers_offload_mode == "model" and hasattr(sd_model, "enable_model_cpu_offload"):
try:
- shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode}')
+ shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner:
shared.opts.diffusers_move_base = False
shared.opts.diffusers_move_unet = False
@@ -337,9 +345,9 @@ def set_diffuser_offload(sd_model, op: str = 'model'):
set_accelerate(sd_model)
except Exception as e:
shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}')
- if hasattr(sd_model, "enable_sequential_cpu_offload") and shared.opts.diffusers_offload_mode == "sequential":
+ if shared.opts.diffusers_offload_mode == "sequential" and hasattr(sd_model, "enable_sequential_cpu_offload"):
try:
- shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode}')
+ shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner:
shared.opts.diffusers_move_base = False
shared.opts.diffusers_move_unet = False
@@ -358,74 +366,159 @@ def set_diffuser_offload(sd_model, op: str = 'model'):
except Exception as e:
shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}')
if shared.opts.diffusers_offload_mode == "balanced":
- try:
- shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode}')
- sd_model = apply_balanced_offload(sd_model)
- except Exception as e:
- shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}')
-
-
-def apply_balanced_offload(sd_model):
- from accelerate import infer_auto_device_map, dispatch_model
- from accelerate.hooks import add_hook_to_module, remove_hook_from_module, ModelHook
+ sd_model = apply_balanced_offload(sd_model)
+ process_timer.add('offload', time.time() - t0)
+
+
+class OffloadHook(accelerate.hooks.ModelHook):
+ def __init__(self, checkpoint_name):
+ if shared.opts.diffusers_offload_max_gpu_memory > 1:
+ shared.opts.diffusers_offload_max_gpu_memory = 0.75
+ if shared.opts.diffusers_offload_max_cpu_memory > 1:
+ shared.opts.diffusers_offload_max_cpu_memory = 0.75
+ self.checkpoint_name = checkpoint_name
+ self.min_watermark = shared.opts.diffusers_offload_min_gpu_memory
+ self.max_watermark = shared.opts.diffusers_offload_max_gpu_memory
+ self.cpu_watermark = shared.opts.diffusers_offload_max_cpu_memory
+ self.gpu = int(shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory * 1024*1024*1024)
+ self.cpu = int(shared.cpu_memory * shared.opts.diffusers_offload_max_cpu_memory * 1024*1024*1024)
+ self.offload_map = {}
+ self.param_map = {}
+ gpu = f'{shared.gpu_memory * shared.opts.diffusers_offload_min_gpu_memory:.3f}-{shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory}:{shared.gpu_memory}'
+ shared.log.info(f'Offload: type=balanced op=init watermark={self.min_watermark}-{self.max_watermark} gpu={gpu} cpu={shared.cpu_memory:.3f} limit={shared.opts.cuda_mem_fraction:.2f}')
+ self.validate()
+ super().__init__()
+
+ def validate(self):
+ if shared.opts.diffusers_offload_mode != 'balanced':
+ return
+ if shared.opts.diffusers_offload_min_gpu_memory < 0 or shared.opts.diffusers_offload_min_gpu_memory > 1:
+ shared.opts.diffusers_offload_min_gpu_memory = 0.25
+ shared.log.warning(f'Offload: type=balanced op=validate: watermark low={shared.opts.diffusers_offload_min_gpu_memory} invalid value')
+ if shared.opts.diffusers_offload_max_gpu_memory < 0.1 or shared.opts.diffusers_offload_max_gpu_memory > 1:
+ shared.opts.diffusers_offload_max_gpu_memory = 0.75
+ shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} invalid value')
+ if shared.opts.diffusers_offload_min_gpu_memory > shared.opts.diffusers_offload_max_gpu_memory:
+ shared.opts.diffusers_offload_min_gpu_memory = shared.opts.diffusers_offload_max_gpu_memory
+ shared.log.warning(f'Offload: type=balanced op=validate: watermark low={shared.opts.diffusers_offload_min_gpu_memory} reset')
+ if shared.opts.diffusers_offload_max_gpu_memory * shared.gpu_memory < 4:
+ shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} low memory')
+
+ def model_size(self):
+ return sum(self.offload_map.values())
+
+ def init_hook(self, module):
+ return module
+
+ def pre_forward(self, module, *args, **kwargs):
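+ # on the first forward of a module that is not on the execution device, infer a device map within the configured GPU/CPU byte budgets, dispatch the module back, and cache the map for reuse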
+ if devices.normalize_device(module.device) != devices.normalize_device(devices.device):
+ device_index = torch.device(devices.device).index
+ if device_index is None:
+ device_index = 0
+ max_memory = { device_index: self.gpu, "cpu": self.cpu }
+ device_map = getattr(module, "balanced_offload_device_map", None)
+ if device_map is None or max_memory != getattr(module, "balanced_offload_max_memory", None):
+ device_map = accelerate.infer_auto_device_map(module, max_memory=max_memory)
+ offload_dir = getattr(module, "offload_dir", os.path.join(shared.opts.accelerate_offload_path, module.__class__.__name__))
+ module = accelerate.dispatch_model(module, device_map=device_map, offload_dir=offload_dir)
+ module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
+ module.balanced_offload_device_map = device_map
+ module.balanced_offload_max_memory = max_memory
+ return args, kwargs
+
+ def post_forward(self, module, output):
+ return output
+
+ def detach_hook(self, module):
+ return module
+
+
+def apply_balanced_offload(sd_model, exclude=[]):
+ global offload_hook_instance # pylint: disable=global-statement
+ if shared.opts.diffusers_offload_mode != "balanced":
+ return sd_model
+ t0 = time.time()
excluded = ['OmniGenPipeline']
if sd_model.__class__.__name__ in excluded:
return sd_model
-
- class dispatch_from_cpu_hook(ModelHook):
- def init_hook(self, module):
- return module
-
- def pre_forward(self, module, *args, **kwargs):
- if devices.normalize_device(module.device) != devices.normalize_device(devices.device):
- device_index = torch.device(devices.device).index
- if device_index is None:
- device_index = 0
- max_memory = {
- device_index: f"{shared.opts.diffusers_offload_max_gpu_memory}GiB",
- "cpu": f"{shared.opts.diffusers_offload_max_cpu_memory}GiB",
- }
- device_map = infer_auto_device_map(module, max_memory=max_memory)
- module = remove_hook_from_module(module, recurse=True)
- offload_dir = getattr(module, "offload_dir", os.path.join(shared.opts.accelerate_offload_path, module.__class__.__name__))
- module = dispatch_model(module, device_map=device_map, offload_dir=offload_dir)
- module = add_hook_to_module(module, dispatch_from_cpu_hook(), append=True)
- module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
- return args, kwargs
-
- def post_forward(self, module, output):
- return output
-
- def detach_hook(self, module):
- return module
+ cached = True
+ checkpoint_name = sd_model.sd_checkpoint_info.name if getattr(sd_model, "sd_checkpoint_info", None) is not None else None
+ if checkpoint_name is None:
+ checkpoint_name = sd_model.__class__.__name__
+ if offload_hook_instance is None or offload_hook_instance.min_watermark != shared.opts.diffusers_offload_min_gpu_memory or offload_hook_instance.max_watermark != shared.opts.diffusers_offload_max_gpu_memory or checkpoint_name != offload_hook_instance.checkpoint_name:
+ cached = False
+ offload_hook_instance = OffloadHook(checkpoint_name)
+
+ def get_pipe_modules(pipe):
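+ # enumerate pipeline components, cache each module's size (GiB) and parameter count on the hook, and return them sorted largest-first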
+ if hasattr(pipe, "_internal_dict"):
+ modules_names = pipe._internal_dict.keys() # pylint: disable=protected-access
+ else:
+ modules_names = get_signature(pipe).keys()
+ modules_names = [m for m in modules_names if m not in exclude and not m.startswith('_')]
+ modules = {}
+ for module_name in modules_names:
+ module_size = offload_hook_instance.offload_map.get(module_name, None)
+ if module_size is None:
+ module = getattr(pipe, module_name, None)
+ if not isinstance(module, torch.nn.Module):
+ continue
+ try:
+ module_size = sum(p.numel() * p.element_size() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024
+ param_num = sum(p.numel() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024
+ except Exception as e:
+ shared.log.error(f'Offload: type=balanced op=calc module={module_name} {e}')
+ module_size = 0
+ param_num = 0
+ offload_hook_instance.offload_map[module_name] = module_size
+ offload_hook_instance.param_map[module_name] = param_num
+ modules[module_name] = module_size
+ modules = sorted(modules.items(), key=lambda x: x[1], reverse=True)
+ return modules
def apply_balanced_offload_to_module(pipe):
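+ # walk the pipeline (including any nested .pipe), offload modules to CPU while GPU usage is above the low watermark, then re-attach the shared hook so modules are dispatched back on demand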
+ used_gpu, used_ram = devices.torch_gc(fast=True)
if hasattr(pipe, "pipe"):
apply_balanced_offload_to_module(pipe.pipe)
if hasattr(pipe, "_internal_dict"):
keys = pipe._internal_dict.keys() # pylint: disable=protected-access
else:
- keys = get_signature(shared.sd_model).keys()
- for module_name in keys: # pylint: disable=protected-access
+ keys = get_signature(pipe).keys()
+ keys = [k for k in keys if k not in exclude and not k.startswith('_')]
+ for module_name, module_size in get_pipe_modules(pipe): # pylint: disable=protected-access
module = getattr(pipe, module_name, None)
- if isinstance(module, torch.nn.Module):
- checkpoint_name = pipe.sd_checkpoint_info.name if getattr(pipe, "sd_checkpoint_info", None) is not None else None
- if checkpoint_name is None:
- checkpoint_name = pipe.__class__.__name__
- offload_dir = os.path.join(shared.opts.accelerate_offload_path, checkpoint_name, module_name)
- module = remove_hook_from_module(module, recurse=True)
- try:
- module = module.to("cpu")
- module.offload_dir = offload_dir
- network_layer_name = getattr(module, "network_layer_name", None)
- module = add_hook_to_module(module, dispatch_from_cpu_hook(), append=True)
- module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
- if network_layer_name:
- module.network_layer_name = network_layer_name
- except Exception as e:
- if 'bitsandbytes' not in str(e):
- shared.log.error(f'Balanced offload: module={module_name} {e}')
- devices.torch_gc(fast=True)
+ if module is None:
+ continue
+ network_layer_name = getattr(module, "network_layer_name", None)
+ device_map = getattr(module, "balanced_offload_device_map", None)
+ max_memory = getattr(module, "balanced_offload_max_memory", None)
+ module = accelerate.hooks.remove_hook_from_module(module, recurse=True)
+ perc_gpu = used_gpu / shared.gpu_memory
+ try:
+ prev_gpu = used_gpu
+ do_offload = (perc_gpu > shared.opts.diffusers_offload_min_gpu_memory) and (module.device != devices.cpu)
+ if do_offload:
+ module = module.to(devices.cpu, non_blocking=True)
+ used_gpu -= module_size
+ if not cached:
+ shared.log.debug(f'Offload: type=balanced module={module_name} cls={module.__class__.__name__} dtype={module.dtype} quant={getattr(module, "quantization_method", None)} params={offload_hook_instance.param_map[module_name]:.3f} size={offload_hook_instance.offload_map[module_name]:.3f}')
+ debug_move(f'Offload: type=balanced op={"move" if do_offload else "skip"} gpu={prev_gpu:.3f}:{used_gpu:.3f} perc={perc_gpu:.2f} ram={used_ram:.3f} current={module.device} dtype={module.dtype} quant={getattr(module, "quantization_method", None)} module={module.__class__.__name__} size={module_size:.3f}')
+ except Exception as e:
+ if 'out of memory' in str(e):
+ devices.torch_gc(fast=True, force=True, reason='oom')
+ elif 'bitsandbytes' in str(e):
+ pass
+ else:
+ shared.log.error(f'Offload: type=balanced op=apply module={module_name} {e}')
+ if os.environ.get('SD_MOVE_DEBUG', None):
+ errors.display(e, f'Offload: type=balanced op=apply module={module_name}')
+ module.offload_dir = os.path.join(shared.opts.accelerate_offload_path, checkpoint_name, module_name)
+ module = accelerate.hooks.add_hook_to_module(module, offload_hook_instance, append=True)
+ module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
+ if network_layer_name:
+ module.network_layer_name = network_layer_name
+ if device_map and max_memory:
+ module.balanced_offload_device_map = device_map
+ module.balanced_offload_max_memory = max_memory
+ devices.torch_gc(fast=True, force=True, reason='offload')
apply_balanced_offload_to_module(sd_model)
if hasattr(sd_model, "pipe"):
@@ -435,6 +528,12 @@ def apply_balanced_offload_to_module(pipe):
if hasattr(sd_model, "decoder_pipe"):
apply_balanced_offload_to_module(sd_model.decoder_pipe)
set_accelerate(sd_model)
+ t = time.time() - t0
+ process_timer.add('offload', t)
+ fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
+ debug_move(f'Apply offload: time={t:.2f} type=balanced fn={fn}')
+ if not cached:
+ shared.log.info(f'Offload: type=balanced op=apply class={sd_model.__class__.__name__} modules={len(offload_hook_instance.offload_map)} size={offload_hook_instance.model_size():.3f}')
return sd_model
@@ -479,7 +578,7 @@ def move_model(model, device=None, force=False):
shared.log.error(f'Model move execution device: device={device} {e}')
if getattr(model, 'has_accelerate', False) and not force:
return
- if hasattr(model, "device") and devices.normalize_device(model.device) == devices.normalize_device(device):
+ if hasattr(model, "device") and devices.normalize_device(model.device) == devices.normalize_device(device) and not force:
return
try:
t0 = time.time()
@@ -512,7 +611,10 @@ def move_model(model, device=None, force=False):
except Exception as e1:
t1 = time.time()
shared.log.error(f'Model move: device={device} {e1}')
- if os.environ.get('SD_MOVE_DEBUG', None) or (t1-t0) > 0.1:
+ if 'move' not in process_timer.records:
+ process_timer.records['move'] = 0
+ process_timer.records['move'] += t1 - t0
+ if os.environ.get('SD_MOVE_DEBUG', None) or (t1-t0) > 2:
shared.log.debug(f'Model move: device={device} class={model.__class__.__name__} accelerate={getattr(model, "has_accelerate", False)} fn={fn} time={t1-t0:.2f}') # pylint: disable=protected-access
devices.torch_gc()
@@ -612,6 +714,9 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='
elif model_type in ['PixArt-Sigma']: # forced pipeline
from modules.model_pixart import load_pixart
sd_model = load_pixart(checkpoint_info, diffusers_load_config)
+ elif model_type in ['Sana']: # forced pipeline
+ from modules.model_sana import load_sana
+ sd_model = load_sana(checkpoint_info, diffusers_load_config)
elif model_type in ['Lumina-Next']: # forced pipeline
from modules.model_lumina import load_lumina
sd_model = load_lumina(checkpoint_info, diffusers_load_config)
@@ -771,7 +876,7 @@ def load_diffuser_file(model_type, pipeline, checkpoint_info, diffusers_load_con
return sd_model
-def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=None, op='model'): # pylint: disable=unused-argument
+def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=None, op='model', revision=None): # pylint: disable=unused-argument
if timer is None:
timer = Timer()
logging.getLogger("diffusers").setLevel(logging.ERROR)
@@ -784,6 +889,8 @@ def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=No
"requires_safety_checker": False, # sd15 specific but we cant know ahead of time
# "use_safetensors": True,
}
+ if revision is not None:
+ diffusers_load_config['revision'] = revision
if shared.opts.diffusers_model_load_variant != 'default':
diffusers_load_config['variant'] = shared.opts.diffusers_model_load_variant
if shared.opts.diffusers_pipeline == 'Custom Diffusers Pipeline' and len(shared.opts.custom_diffusers_pipeline) > 0:
@@ -925,7 +1032,8 @@ def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=No
shared.log.error(f"Load {op}: {e}")
errors.display(e, "Model")
- devices.torch_gc(force=True)
+ if shared.opts.diffusers_offload_mode != 'balanced':
+ devices.torch_gc(force=True)
if sd_model is not None:
script_callbacks.model_loaded_callback(sd_model)
@@ -961,6 +1069,11 @@ def get_signature(cls):
return signature.parameters
+def get_call(cls):
+ signature = inspect.signature(cls.__call__, follow_wrapped=True, eval_str=True)
+ return signature.parameters
+
+
def switch_pipe(cls: diffusers.DiffusionPipeline, pipeline: diffusers.DiffusionPipeline = None, force = False, args = {}):
"""
args:
@@ -1077,6 +1190,9 @@ def set_diffuser_pipe(pipe, new_pipe_type):
'OmniGenPipeline',
'StableDiffusion3ControlNetPipeline',
'InstantIRPipeline',
+ 'FluxFillPipeline',
+ 'FluxControlPipeline',
+ 'StableVideoDiffusionPipeline',
]
n = getattr(pipe.__class__, '__name__', '')
@@ -1345,7 +1461,7 @@ def reload_text_encoder(initial=False):
set_t5(pipe=shared.sd_model, module='text_encoder_3', t5=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir)
-def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model', force=False):
+def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model', force=False, revision=None):
load_dict = shared.opts.sd_model_dict != model_data.sd_dict
from modules import lowvram, sd_hijack
checkpoint_info = info or select_checkpoint(op=op) # are we selecting model or dictionary
@@ -1380,7 +1496,7 @@ def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model',
unload_model_weights(op=op)
sd_model = None
timer = Timer()
- # TODO implement caching after diffusers implement state_dict loading
+ # TODO model loader: implement model in-memory caching
state_dict = get_checkpoint_state_dict(checkpoint_info, timer) if not shared.native else None
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
timer.record("config")
@@ -1390,7 +1506,7 @@ def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model',
load_model(checkpoint_info, already_loaded_state_dict=state_dict, timer=timer, op=op)
model_data.sd_dict = shared.opts.sd_model_dict
else:
- load_diffuser(checkpoint_info, already_loaded_state_dict=state_dict, timer=timer, op=op)
+ load_diffuser(checkpoint_info, already_loaded_state_dict=state_dict, timer=timer, op=op, revision=revision)
if load_dict and next_checkpoint_info is not None:
model_data.sd_dict = shared.opts.sd_model_dict
shared.opts.data["sd_model_checkpoint"] = next_checkpoint_info.title
@@ -1441,10 +1557,17 @@ def disable_offload(sd_model):
from accelerate.hooks import remove_hook_from_module
if not getattr(sd_model, 'has_accelerate', False):
return
- if hasattr(sd_model, 'components'):
- for _name, model in sd_model.components.items():
- if isinstance(model, torch.nn.Module):
- remove_hook_from_module(model, recurse=True)
+ if hasattr(sd_model, "_internal_dict"):
+ keys = sd_model._internal_dict.keys() # pylint: disable=protected-access
+ else:
+ keys = get_signature(sd_model).keys()
+ for module_name in keys: # pylint: disable=protected-access
+ module = getattr(sd_model, module_name, None)
+ if isinstance(module, torch.nn.Module):
+ network_layer_name = getattr(module, "network_layer_name", None)
+ module = remove_hook_from_module(module, recurse=True)
+ if network_layer_name:
+ module.network_layer_name = network_layer_name
sd_model.has_accelerate = False
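
The reworked `disable_offload` above enumerates pipeline submodules from the diffusers `_internal_dict` registry (falling back to the constructor signature) and restores the `network_layer_name` attribute that the LoRA code attaches, since accelerate's hook removal can otherwise drop it. A minimal sketch of the same pattern as a standalone helper; the helper name is illustrative, only `remove_hook_from_module` is the real accelerate API:

```python
# Sketch, not SD.Next code: strip accelerate offload hooks from every
# torch.nn.Module registered on a diffusers-style pipeline.
import torch
from accelerate.hooks import remove_hook_from_module

def strip_offload_hooks(pipe):
    keys = pipe._internal_dict.keys() if hasattr(pipe, "_internal_dict") else []
    for name in keys:
        module = getattr(pipe, name, None)
        if not isinstance(module, torch.nn.Module):
            continue
        layer_name = getattr(module, "network_layer_name", None)  # LoRA bookkeeping, if present
        module = remove_hook_from_module(module, recurse=True)
        if layer_name is not None:
            module.network_layer_name = layer_name  # re-attach after hooks are gone
    pipe.has_accelerate = False
```
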
diff --git a/modules/sd_models_compile.py b/modules/sd_models_compile.py
index 91ed84ded..20a7d7de2 100644
--- a/modules/sd_models_compile.py
+++ b/modules/sd_models_compile.py
@@ -47,7 +47,7 @@ def apply_compile_to_model(sd_model, function, options, op=None):
sd_model.prior_pipe.prior.clip_txt_pooled_mapper = backup_clip_txt_pooled_mapper
if "Text Encoder" in options:
if hasattr(sd_model, 'text_encoder') and hasattr(sd_model.text_encoder, 'config'):
- if hasattr(sd_model, 'decoder_pipe') and hasattr(sd_model.decoder_pipe, 'text_encoder'):
+ if hasattr(sd_model, 'decoder_pipe') and hasattr(sd_model.decoder_pipe, 'text_encoder') and hasattr(sd_model.decoder_pipe.text_encoder, 'config'):
sd_model.decoder_pipe.text_encoder = function(sd_model.decoder_pipe.text_encoder, op="decoder_pipe.text_encoder", sd_model=sd_model)
else:
if op == "nncf" and sd_model.text_encoder.__class__.__name__ in {"T5EncoderModel", "UMT5EncoderModel"}:
@@ -76,7 +76,7 @@ def apply_compile_to_model(sd_model, function, options, op=None):
dtype=torch.float32 if devices.dtype != torch.bfloat16 else torch.bfloat16
)
sd_model.text_encoder_3 = function(sd_model.text_encoder_3, op="text_encoder_3", sd_model=sd_model)
- if hasattr(sd_model, 'prior_pipe') and hasattr(sd_model.prior_pipe, 'text_encoder'):
+ if hasattr(sd_model, 'prior_pipe') and hasattr(sd_model.prior_pipe, 'text_encoder') and hasattr(sd_model.prior_pipe.text_encoder, 'config'):
sd_model.prior_pipe.text_encoder = function(sd_model.prior_pipe.text_encoder, op="prior_pipe.text_encoder", sd_model=sd_model)
if "VAE" in options:
if hasattr(sd_model, 'vae') and hasattr(sd_model.vae, 'decode'):
@@ -505,51 +505,26 @@ def compile_diffusers(sd_model):
def torchao_quantization(sd_model):
try:
- install('torchao', quiet=True)
+ install('torchao==0.7.0', quiet=True)
from torchao import quantization as q
except Exception as e:
shared.log.error(f"Quantization: type=TorchAO quantization not supported: {e}")
return sd_model
- if shared.opts.torchao_quantization_type == "int8+act":
- fn = q.int8_dynamic_activation_int8_weight
- elif shared.opts.torchao_quantization_type == "int8":
- fn = q.int8_weight_only
- elif shared.opts.torchao_quantization_type == "int4":
- fn = q.int4_weight_only
- elif shared.opts.torchao_quantization_type == "fp8+act":
- fn = q.float8_dynamic_activation_float8_weight
- elif shared.opts.torchao_quantization_type == "fp8":
- fn = q.float8_weight_only
- elif shared.opts.torchao_quantization_type == "fpx":
- fn = q.fpx_weight_only
- else:
+
+ fn = getattr(q, shared.opts.torchao_quantization_type, None)
+ if fn is None:
shared.log.error(f"Quantization: type=TorchAO type={shared.opts.torchao_quantization_type} not supported")
return sd_model
+ def torchao_model(model, op=None, sd_model=None): # pylint: disable=unused-argument
+ q.quantize_(model, fn(), device=devices.device)
+ return model
+
shared.log.info(f"Quantization: type=TorchAO pipe={sd_model.__class__.__name__} quant={shared.opts.torchao_quantization_type} fn={fn} targets={shared.opts.torchao_quantization}")
try:
t0 = time.time()
- modules = []
- if hasattr(sd_model, 'unet') and 'Model' in shared.opts.torchao_quantization:
- modules.append('unet')
- q.quantize_(sd_model.unet, fn(), device=devices.device)
- if hasattr(sd_model, 'transformer') and 'Model' in shared.opts.torchao_quantization:
- modules.append('transformer')
- q.quantize_(sd_model.transformer, fn(), device=devices.device)
- # sd_model.transformer = q.autoquant(sd_model.transformer, error_on_unseen=False)
- if hasattr(sd_model, 'vae') and 'VAE' in shared.opts.torchao_quantization:
- modules.append('vae')
- q.quantize_(sd_model.vae, fn(), device=devices.device)
- if hasattr(sd_model, 'text_encoder') and 'Text Encoder' in shared.opts.torchao_quantization:
- modules.append('te1')
- q.quantize_(sd_model.text_encoder, fn(), device=devices.device)
- if hasattr(sd_model, 'text_encoder_2') and 'Text Encoder' in shared.opts.torchao_quantization:
- modules.append('te2')
- q.quantize_(sd_model.text_encoder_2, fn(), device=devices.device)
- if hasattr(sd_model, 'text_encoder_3') and 'Text Encoder' in shared.opts.torchao_quantization:
- modules.append('te3')
- q.quantize_(sd_model.text_encoder_3, fn(), device=devices.device)
+ apply_compile_to_model(sd_model, torchao_model, shared.opts.torchao_quantization, op="torchao")
t1 = time.time()
- shared.log.info(f"Quantization: type=TorchAO modules={modules} time={t1-t0:.2f}")
+ shared.log.info(f"Quantization: type=TorchAO time={t1-t0:.2f}")
except Exception as e:
shared.log.error(f"Quantization: type=TorchAO {e}")
setup_logging() # torchao uses dynamo which messes with logging so reset is needed
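
The TorchAO change above drops the hand-written if/elif ladder in favor of resolving the quantization factory by name from `torchao.quantization`, then reuses `apply_compile_to_model` to decide which submodules get quantized. A reduced sketch of the lookup-and-quantize step for a single module; `torchao_quantize` is an invented name, while `quantize_` and `int8_weight_only` are the TorchAO calls used above:

```python
# Sketch: in-place TorchAO weight quantization with the config factory picked by name.
from torchao import quantization as q

def torchao_quantize(model, quant_type: str = "int8_weight_only", device=None):
    fn = getattr(q, quant_type, None)  # e.g. int8_weight_only, int4_weight_only, float8_weight_only
    if fn is None:
        raise ValueError(f"unsupported TorchAO quantization type: {quant_type}")
    q.quantize_(model, fn(), device=device)  # replaces weights in place
    return model
```
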
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index e560744dd..dc58a2419 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -3,6 +3,7 @@
from modules import shared
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # pylint: disable=unused-import
+
debug = shared.log.trace if os.environ.get('SD_SAMPLER_DEBUG', None) is not None else lambda *args, **kwargs: None
debug('Trace: SAMPLER')
all_samplers = []
@@ -47,6 +48,8 @@ def visible_sampler_names():
def create_sampler(name, model):
+ if name is None or name == 'None':
+ return model.scheduler
try:
current = model.scheduler.__class__.__name__
except Exception:
@@ -73,15 +76,15 @@ def create_sampler(name, model):
shared.log.debug(f'Sampler: sampler="{name}" config={config.options}')
return sampler
elif shared.native:
- FlowModels = ['Flux', 'StableDiffusion3', 'Lumina', 'AuraFlow']
+ FlowModels = ['Flux', 'StableDiffusion3', 'Lumina', 'AuraFlow', 'Sana']
if 'KDiffusion' in model.__class__.__name__:
return None
- if any(x in model.__class__.__name__ for x in FlowModels) and 'FlowMatch' not in name:
- shared.log.warning(f'Sampler: default={current} target="{name}" class={model.__class__.__name__} linear scheduler unsupported')
- return None
if not any(x in model.__class__.__name__ for x in FlowModels) and 'FlowMatch' in name:
shared.log.warning(f'Sampler: default={current} target="{name}" class={model.__class__.__name__} flow-match scheduler unsupported')
return None
+ # if any(x in model.__class__.__name__ for x in FlowModels) and 'FlowMatch' not in name:
+ # shared.log.warning(f'Sampler: default={current} target="{name}" class={model.__class__.__name__} linear scheduler unsupported')
+ # return None
sampler = config.constructor(model)
if sampler is None:
sampler = config.constructor(model)
@@ -91,7 +94,8 @@ def create_sampler(name, model):
if hasattr(model, "prior_pipe") and hasattr(model.prior_pipe, "scheduler"):
model.prior_pipe.scheduler = sampler.sampler
model.prior_pipe.scheduler.config.clip_sample = False
- shared.log.debug(f'Sampler: sampler="{sampler.name}" class="{model.scheduler.__class__.__name__} config={sampler.config}')
+ clean_config = {k: v for k, v in sampler.config.items() if v is not None and v is not False}
+ shared.log.debug(f'Sampler: sampler="{sampler.name}" class="{model.scheduler.__class__.__name__}" config={clean_config}')
return sampler.sampler
else:
return None
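
With the hunk above, `create_sampler` keeps the pipeline's current scheduler when no sampler is selected and only rejects the one combination still known to fail: a flow-match scheduler on a non-flow pipeline. Linear schedulers on flow-matching models (Flux, SD3, Lumina, AuraFlow, Sana) are now attempted instead of being refused. A toy predicate illustrating the relaxed gating; the class names are examples only:

```python
# Sketch of the new sampler compatibility check.
FLOW_MODELS = ['Flux', 'StableDiffusion3', 'Lumina', 'AuraFlow', 'Sana']

def sampler_allowed(pipeline_class: str, sampler_name: str) -> bool:
    is_flow = any(x in pipeline_class for x in FLOW_MODELS)
    if not is_flow and 'FlowMatch' in sampler_name:
        return False  # flow-match scheduler on an epsilon/v-prediction pipeline: still unsupported
    return True       # everything else is attempted, including linear schedulers on flow models

print(sampler_allowed('FluxPipeline', 'DPM++ 2M'))                      # True (previously rejected)
print(sampler_allowed('StableDiffusionXLPipeline', 'Euler FlowMatch'))  # False
```
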
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index a487fe9b7..a90ceec27 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -1,13 +1,15 @@
+import time
import threading
from collections import namedtuple
import torch
import torchvision.transforms as T
from PIL import Image
-from modules import shared, devices, processing, images, sd_vae_approx, sd_vae_taesd, sd_vae_stablecascade, sd_samplers
+from modules import shared, devices, processing, images, sd_vae_approx, sd_vae_taesd, sd_vae_stablecascade, sd_samplers, timer
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
approximation_indexes = { "Simple": 0, "Approximate": 1, "TAESD": 2, "Full VAE": 3 }
+flow_models = ['f1', 'sd3', 'lumina', 'auraflow', 'sana']
warned = False
queue_lock = threading.Lock()
@@ -33,13 +35,12 @@ def setup_img2img_steps(p, steps=None):
def single_sample_to_image(sample, approximation=None):
with queue_lock:
- sd_cascade = False
+ t0 = time.time()
if approximation is None:
approximation = approximation_indexes.get(shared.opts.show_progress_type, None)
if approximation is None:
warn_once('Unknown decode type')
approximation = 0
- # normal sample is [4,64,64]
try:
if sample.dtype == torch.bfloat16 and (approximation == 0 or approximation == 1):
sample = sample.to(torch.float16)
@@ -48,22 +49,15 @@ def single_sample_to_image(sample, approximation=None):
if len(sample.shape) > 4: # likely unknown video latent (e.g. svd)
return Image.new(mode="RGB", size=(512, 512))
- if len(sample) == 16: # sd_cascade
- sd_cascade = True
if len(sample.shape) == 4 and sample.shape[0]: # likely animatediff latent
sample = sample.permute(1, 0, 2, 3)[0]
- if shared.native: # [-x,x] to [-5,5]
- sample_max = torch.max(sample)
- if sample_max > 5:
- sample = sample * (5 / sample_max)
- sample_min = torch.min(sample)
- if sample_min < -5:
- sample = sample * (5 / abs(sample_min))
-
if approximation == 2: # TAESD
+ if shared.opts.live_preview_downscale and (sample.shape[-1] > 128 or sample.shape[-2] > 128):
+ scale = 128 / max(sample.shape[-1], sample.shape[-2])
+ sample = torch.nn.functional.interpolate(sample.unsqueeze(0), scale_factor=[scale, scale], mode='bilinear', align_corners=False)[0]
x_sample = sd_vae_taesd.decode(sample)
x_sample = (1.0 + x_sample) / 2.0 # preview requires smaller range
- elif sd_cascade and approximation != 3:
+ elif shared.sd_model_type == 'sc' and approximation != 3:
x_sample = sd_vae_stablecascade.decode(sample)
elif approximation == 0: # Simple
x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
@@ -84,6 +78,8 @@ def single_sample_to_image(sample, approximation=None):
except Exception as e:
warn_once(f'Preview: {e}')
image = Image.new(mode="RGB", size=(512, 512))
+ t1 = time.time()
+ timer.process.add('preview', t1 - t0)
return image
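
The preview path above now records decode time per call and, when the `live_preview_downscale` option is enabled, shrinks large latents before TAESD decoding so live previews stay cheap at high resolutions. A self-contained sketch of just the downscale step, assuming a `[C, H, W]` latent tensor and ignoring the option gating:

```python
# Sketch: cap a latent at 128 px per side before a cheap preview decode.
import torch

def downscale_latent(sample: torch.Tensor, limit: int = 128) -> torch.Tensor:
    if sample.shape[-1] > limit or sample.shape[-2] > limit:
        scale = limit / max(sample.shape[-1], sample.shape[-2])
        sample = torch.nn.functional.interpolate(
            sample.unsqueeze(0),            # add batch dim for interpolate
            scale_factor=[scale, scale],
            mode='bilinear', align_corners=False,
        )[0]                                # drop batch dim again
    return sample
```
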
diff --git a/modules/sd_samplers_diffusers.py b/modules/sd_samplers_diffusers.py
index 60c75b64e..6c05b2045 100644
--- a/modules/sd_samplers_diffusers.py
+++ b/modules/sd_samplers_diffusers.py
@@ -4,13 +4,12 @@
import inspect
import diffusers
from modules import shared, errors
-from modules import sd_samplers_common
+from modules.sd_samplers_common import SamplerData, flow_models
debug = shared.log.trace if os.environ.get('SD_SAMPLER_DEBUG', None) is not None else lambda *args, **kwargs: None
debug('Trace: SAMPLER')
-
try:
from diffusers import (
CMStochasticIterativeScheduler,
@@ -52,6 +51,8 @@
from modules.schedulers.scheduler_dc import DCSolverMultistepScheduler # pylint: disable=ungrouped-imports
from modules.schedulers.scheduler_vdm import VDMScheduler # pylint: disable=ungrouped-imports
from modules.schedulers.scheduler_dpm_flowmatch import FlowMatchDPMSolverMultistepScheduler # pylint: disable=ungrouped-imports
+ from modules.schedulers.scheduler_bdia import BDIA_DDIMScheduler # pylint: disable=ungrouped-imports
+ from modules.schedulers.scheduler_ufogen import UFOGenScheduler # pylint: disable=ungrouped-imports
except Exception as e:
shared.log.error(f'Diffusers import error: version={diffusers.__version__} error: {e}')
if os.environ.get('SD_SAMPLER_DEBUG', None) is not None:
@@ -62,41 +63,43 @@
# prediction_type is ideally set in the model as well, but we may need to auto-detect the model type in the future
'All': { 'num_train_timesteps': 1000, 'beta_start': 0.0001, 'beta_end': 0.02, 'beta_schedule': 'linear', 'prediction_type': 'epsilon' },
- 'UniPC': { 'predict_x0': True, 'sample_max_value': 1.0, 'solver_order': 2, 'solver_type': 'bh2', 'thresholding': False, 'use_beta_sigmas': False, 'use_exponential_sigmas': False, 'use_karras_sigmas': False, 'lower_order_final': True, 'timestep_spacing': 'linspace', 'final_sigmas_type': 'zero', 'rescale_betas_zero_snr': False },
+ 'UniPC': { 'predict_x0': True, 'sample_max_value': 1.0, 'solver_order': 2, 'solver_type': 'bh2', 'thresholding': False, 'use_beta_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_karras_sigmas': False, 'lower_order_final': True, 'timestep_spacing': 'linspace', 'final_sigmas_type': 'zero', 'rescale_betas_zero_snr': False },
'DDIM': { 'clip_sample': False, 'set_alpha_to_one': True, 'steps_offset': 0, 'clip_sample_range': 1.0, 'sample_max_value': 1.0, 'timestep_spacing': 'leading', 'rescale_betas_zero_snr': False, 'thresholding': False },
'Euler': { 'steps_offset': 0, 'interpolation_type': "linear", 'rescale_betas_zero_snr': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'use_beta_sigmas': False, 'use_exponential_sigmas': False, 'use_karras_sigmas': False },
'Euler a': { 'steps_offset': 0, 'rescale_betas_zero_snr': False, 'timestep_spacing': 'linspace' },
'Euler SGM': { 'steps_offset': 0, 'interpolation_type': "linear", 'rescale_betas_zero_snr': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'trailing', 'use_beta_sigmas': False, 'use_exponential_sigmas': False, 'use_karras_sigmas': False, 'prediction_type': "sample" },
'Euler EDM': { 'sigma_schedule': "karras" },
- 'Euler FlowMatch': { 'timestep_spacing': "linspace", 'shift': 1, 'use_dynamic_shifting': False },
+ 'Euler FlowMatch': { 'timestep_spacing': "linspace", 'shift': 1, 'use_dynamic_shifting': False, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False },
- 'DPM++': { 'solver_order': 2, 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'final_sigmas_type': 'sigma_min' },
- 'DPM++ 1S': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 1 },
- 'DPM++ 2M': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 2 },
- 'DPM++ 3M': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 3 },
- 'DPM++ 2M SDE': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "sde-dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 2 },
+ 'DPM++': { 'solver_order': 2, 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'final_sigmas_type': 'sigma_min' },
+ 'DPM++ 1S': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 1 },
+ 'DPM++ 2M': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 2 },
+ 'DPM++ 3M': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 3 },
+ 'DPM++ 2M SDE': { 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "sde-dpmsolver++", 'solver_type': "midpoint", 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False, 'use_lu_lambdas': False, 'final_sigmas_type': 'zero', 'timestep_spacing': 'linspace', 'solver_order': 2 },
'DPM++ 2M EDM': { 'solver_order': 2, 'solver_type': 'midpoint', 'final_sigmas_type': 'zero', 'algorithm_type': 'dpmsolver++' },
'DPM++ Cosine': { 'solver_order': 2, 'sigma_schedule': "exponential", 'prediction_type': "v-prediction" },
'DPM SDE': { 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'noise_sampler_seed': None, 'timestep_spacing': 'linspace', 'steps_offset': 0, },
- 'DPM2 FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver2', 'use_noise_sampler': True },
- 'DPM2a FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver2A', 'use_noise_sampler': True },
- 'DPM2++ 2M FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++2M', 'use_noise_sampler': True },
- 'DPM2++ 2S FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++2S', 'use_noise_sampler': True },
- 'DPM2++ SDE FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++sde', 'use_noise_sampler': True },
- 'DPM2++ 2M SDE FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++2Msde', 'use_noise_sampler': True },
- 'DPM2++ 3M SDE FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 3, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++3Msde', 'use_noise_sampler': True },
+ 'DPM2 FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver2', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012 },
+ 'DPM2a FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver2A', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012 },
+ 'DPM2++ 2M FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++2M', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012 },
+ 'DPM2++ 2S FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++2S', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012 },
+ 'DPM2++ SDE FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++sde', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012 },
+ 'DPM2++ 2M SDE FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 2, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++2Msde', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012 },
+ 'DPM2++ 3M SDE FlowMatch': { 'shift': 1, 'use_dynamic_shifting': False, 'solver_order': 3, 'sigma_schedule': None, 'use_beta_sigmas': False, 'algorithm_type': 'dpmsolver++3Msde', 'use_noise_sampler': True, 'beta_start': 0.00085, 'beta_end': 0.012 },
'Heun': { 'use_beta_sigmas': False, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'timestep_spacing': 'linspace' },
'Heun FlowMatch': { 'timestep_spacing': "linspace", 'shift': 1 },
- 'DEIS': { 'solver_order': 2, 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "deis", 'solver_type': "logrho", 'lower_order_final': True, 'timestep_spacing': 'linspace', 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False },
- 'SA Solver': {'predictor_order': 2, 'corrector_order': 2, 'thresholding': False, 'lower_order_final': True, 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'timestep_spacing': 'linspace'},
+ 'DEIS': { 'solver_order': 2, 'thresholding': False, 'sample_max_value': 1.0, 'algorithm_type': "deis", 'solver_type': "logrho", 'lower_order_final': True, 'timestep_spacing': 'linspace', 'use_karras_sigmas': False, 'use_exponential_sigmas': False, 'use_flow_sigmas': False, 'use_beta_sigmas': False },
+ 'SA Solver': {'predictor_order': 2, 'corrector_order': 2, 'thresholding': False, 'lower_order_final': True, 'use_karras_sigmas': False, 'use_flow_sigmas': False, 'use_exponential_sigmas': False, 'use_beta_sigmas': False, 'timestep_spacing': 'linspace'},
'DC Solver': { 'beta_start': 0.0001, 'beta_end': 0.02, 'solver_order': 2, 'prediction_type': "epsilon", 'thresholding': False, 'solver_type': 'bh2', 'lower_order_final': True, 'dc_order': 2, 'disable_corrector': [0] },
'VDM Solver': { 'clip_sample_range': 2.0, },
'LCM': { 'beta_start': 0.00085, 'beta_end': 0.012, 'beta_schedule': "scaled_linear", 'set_alpha_to_one': True, 'rescale_betas_zero_snr': False, 'thresholding': False, 'timestep_spacing': 'linspace' },
'TCD': { 'set_alpha_to_one': True, 'rescale_betas_zero_snr': False, 'beta_schedule': 'scaled_linear' },
+ 'UFOGen': {},
+ 'BDIA DDIM': { 'clip_sample': False, 'set_alpha_to_one': True, 'steps_offset': 0, 'clip_sample_range': 1.0, 'sample_max_value': 1.0, 'timestep_spacing': 'leading', 'rescale_betas_zero_snr': False, 'thresholding': False, 'gamma': 1.0 },
'PNDM': { 'skip_prk_steps': False, 'set_alpha_to_one': False, 'steps_offset': 0, 'timestep_spacing': 'linspace' },
'IPNDM': { },
@@ -108,53 +111,55 @@
}
samplers_data_diffusers = [
- sd_samplers_common.SamplerData('Default', None, [], {}),
+ SamplerData('Default', None, [], {}),
- sd_samplers_common.SamplerData('UniPC', lambda model: DiffusionSampler('UniPC', UniPCMultistepScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DDIM', lambda model: DiffusionSampler('DDIM', DDIMScheduler, model), [], {}),
- sd_samplers_common.SamplerData('Euler', lambda model: DiffusionSampler('Euler', EulerDiscreteScheduler, model), [], {}),
- sd_samplers_common.SamplerData('Euler a', lambda model: DiffusionSampler('Euler a', EulerAncestralDiscreteScheduler, model), [], {}),
- sd_samplers_common.SamplerData('Euler SGM', lambda model: DiffusionSampler('Euler SGM', EulerDiscreteScheduler, model), [], {}),
- sd_samplers_common.SamplerData('Euler EDM', lambda model: DiffusionSampler('Euler EDM', EDMEulerScheduler, model), [], {}),
- sd_samplers_common.SamplerData('Euler FlowMatch', lambda model: DiffusionSampler('Euler FlowMatch', FlowMatchEulerDiscreteScheduler, model), [], {}),
+ SamplerData('UniPC', lambda model: DiffusionSampler('UniPC', UniPCMultistepScheduler, model), [], {}),
+ SamplerData('DDIM', lambda model: DiffusionSampler('DDIM', DDIMScheduler, model), [], {}),
+ SamplerData('Euler', lambda model: DiffusionSampler('Euler', EulerDiscreteScheduler, model), [], {}),
+ SamplerData('Euler a', lambda model: DiffusionSampler('Euler a', EulerAncestralDiscreteScheduler, model), [], {}),
+ SamplerData('Euler SGM', lambda model: DiffusionSampler('Euler SGM', EulerDiscreteScheduler, model), [], {}),
+ SamplerData('Euler EDM', lambda model: DiffusionSampler('Euler EDM', EDMEulerScheduler, model), [], {}),
+ SamplerData('Euler FlowMatch', lambda model: DiffusionSampler('Euler FlowMatch', FlowMatchEulerDiscreteScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DPM++', lambda model: DiffusionSampler('DPM++', DPMSolverSinglestepScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DPM++ 1S', lambda model: DiffusionSampler('DPM++ 1S', DPMSolverMultistepScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DPM++ 2M', lambda model: DiffusionSampler('DPM++ 2M', DPMSolverMultistepScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DPM++ 3M', lambda model: DiffusionSampler('DPM++ 3M', DPMSolverMultistepScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DPM++ 2M SDE', lambda model: DiffusionSampler('DPM++ 2M SDE', DPMSolverMultistepScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DPM++ 2M EDM', lambda model: DiffusionSampler('DPM++ 2M EDM', EDMDPMSolverMultistepScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DPM++ Cosine', lambda model: DiffusionSampler('DPM++ 2M EDM', CosineDPMSolverMultistepScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DPM SDE', lambda model: DiffusionSampler('DPM SDE', DPMSolverSDEScheduler, model), [], {}),
+ SamplerData('DPM++', lambda model: DiffusionSampler('DPM++', DPMSolverSinglestepScheduler, model), [], {}),
+ SamplerData('DPM++ 1S', lambda model: DiffusionSampler('DPM++ 1S', DPMSolverMultistepScheduler, model), [], {}),
+ SamplerData('DPM++ 2M', lambda model: DiffusionSampler('DPM++ 2M', DPMSolverMultistepScheduler, model), [], {}),
+ SamplerData('DPM++ 3M', lambda model: DiffusionSampler('DPM++ 3M', DPMSolverMultistepScheduler, model), [], {}),
+ SamplerData('DPM++ 2M SDE', lambda model: DiffusionSampler('DPM++ 2M SDE', DPMSolverMultistepScheduler, model), [], {}),
+ SamplerData('DPM++ 2M EDM', lambda model: DiffusionSampler('DPM++ 2M EDM', EDMDPMSolverMultistepScheduler, model), [], {}),
+ SamplerData('DPM++ Cosine', lambda model: DiffusionSampler('DPM++ 2M EDM', CosineDPMSolverMultistepScheduler, model), [], {}),
+ SamplerData('DPM SDE', lambda model: DiffusionSampler('DPM SDE', DPMSolverSDEScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DPM2 FlowMatch', lambda model: DiffusionSampler('DPM2 FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DPM2a FlowMatch', lambda model: DiffusionSampler('DPM2a FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DPM2++ 2M FlowMatch', lambda model: DiffusionSampler('DPM2++ 2M FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DPM2++ 2S FlowMatch', lambda model: DiffusionSampler('DPM2++ 2S FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DPM2++ SDE FlowMatch', lambda model: DiffusionSampler('DPM2++ SDE FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DPM2++ 2M SDE FlowMatch', lambda model: DiffusionSampler('DPM2++ 2M SDE FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DPM2++ 3M SDE FlowMatch', lambda model: DiffusionSampler('DPM2++ 3M SDE FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
+ SamplerData('DPM2 FlowMatch', lambda model: DiffusionSampler('DPM2 FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
+ SamplerData('DPM2a FlowMatch', lambda model: DiffusionSampler('DPM2a FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
+ SamplerData('DPM2++ 2M FlowMatch', lambda model: DiffusionSampler('DPM2++ 2M FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
+ SamplerData('DPM2++ 2S FlowMatch', lambda model: DiffusionSampler('DPM2++ 2S FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
+ SamplerData('DPM2++ SDE FlowMatch', lambda model: DiffusionSampler('DPM2++ SDE FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
+ SamplerData('DPM2++ 2M SDE FlowMatch', lambda model: DiffusionSampler('DPM2++ 2M SDE FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
+ SamplerData('DPM2++ 3M SDE FlowMatch', lambda model: DiffusionSampler('DPM2++ 3M SDE FlowMatch', FlowMatchDPMSolverMultistepScheduler, model), [], {}),
- sd_samplers_common.SamplerData('Heun', lambda model: DiffusionSampler('Heun', HeunDiscreteScheduler, model), [], {}),
- sd_samplers_common.SamplerData('Heun FlowMatch', lambda model: DiffusionSampler('Heun FlowMatch', FlowMatchHeunDiscreteScheduler, model), [], {}),
+ SamplerData('Heun', lambda model: DiffusionSampler('Heun', HeunDiscreteScheduler, model), [], {}),
+ SamplerData('Heun FlowMatch', lambda model: DiffusionSampler('Heun FlowMatch', FlowMatchHeunDiscreteScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DEIS', lambda model: DiffusionSampler('DEIS', DEISMultistepScheduler, model), [], {}),
- sd_samplers_common.SamplerData('SA Solver', lambda model: DiffusionSampler('SA Solver', SASolverScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DC Solver', lambda model: DiffusionSampler('DC Solver', DCSolverMultistepScheduler, model), [], {}),
- sd_samplers_common.SamplerData('VDM Solver', lambda model: DiffusionSampler('VDM Solver', VDMScheduler, model), [], {}),
+ SamplerData('DEIS', lambda model: DiffusionSampler('DEIS', DEISMultistepScheduler, model), [], {}),
+ SamplerData('SA Solver', lambda model: DiffusionSampler('SA Solver', SASolverScheduler, model), [], {}),
+ SamplerData('DC Solver', lambda model: DiffusionSampler('DC Solver', DCSolverMultistepScheduler, model), [], {}),
+ SamplerData('VDM Solver', lambda model: DiffusionSampler('VDM Solver', VDMScheduler, model), [], {}),
+ SamplerData('BDIA DDIM', lambda model: DiffusionSampler('BDIA DDIM g=0', BDIA_DDIMScheduler, model), [], {}),
- sd_samplers_common.SamplerData('PNDM', lambda model: DiffusionSampler('PNDM', PNDMScheduler, model), [], {}),
- sd_samplers_common.SamplerData('IPNDM', lambda model: DiffusionSampler('IPNDM', IPNDMScheduler, model), [], {}),
- sd_samplers_common.SamplerData('DDPM', lambda model: DiffusionSampler('DDPM', DDPMScheduler, model), [], {}),
- sd_samplers_common.SamplerData('LMSD', lambda model: DiffusionSampler('LMSD', LMSDiscreteScheduler, model), [], {}),
- sd_samplers_common.SamplerData('KDPM2', lambda model: DiffusionSampler('KDPM2', KDPM2DiscreteScheduler, model), [], {}),
- sd_samplers_common.SamplerData('KDPM2 a', lambda model: DiffusionSampler('KDPM2 a', KDPM2AncestralDiscreteScheduler, model), [], {}),
- sd_samplers_common.SamplerData('CMSI', lambda model: DiffusionSampler('CMSI', CMStochasticIterativeScheduler, model), [], {}),
+ SamplerData('PNDM', lambda model: DiffusionSampler('PNDM', PNDMScheduler, model), [], {}),
+ SamplerData('IPNDM', lambda model: DiffusionSampler('IPNDM', IPNDMScheduler, model), [], {}),
+ SamplerData('DDPM', lambda model: DiffusionSampler('DDPM', DDPMScheduler, model), [], {}),
+ SamplerData('LMSD', lambda model: DiffusionSampler('LMSD', LMSDiscreteScheduler, model), [], {}),
+ SamplerData('KDPM2', lambda model: DiffusionSampler('KDPM2', KDPM2DiscreteScheduler, model), [], {}),
+ SamplerData('KDPM2 a', lambda model: DiffusionSampler('KDPM2 a', KDPM2AncestralDiscreteScheduler, model), [], {}),
+ SamplerData('CMSI', lambda model: DiffusionSampler('CMSI', CMStochasticIterativeScheduler, model), [], {}),
- sd_samplers_common.SamplerData('LCM', lambda model: DiffusionSampler('LCM', LCMScheduler, model), [], {}),
- sd_samplers_common.SamplerData('TCD', lambda model: DiffusionSampler('TCD', TCDScheduler, model), [], {}),
+ SamplerData('LCM', lambda model: DiffusionSampler('LCM', LCMScheduler, model), [], {}),
+ SamplerData('TCD', lambda model: DiffusionSampler('TCD', TCDScheduler, model), [], {}),
+ SamplerData('UFOGen', lambda model: DiffusionSampler('UFOGen', UFOGenScheduler, model), [], {}),
- sd_samplers_common.SamplerData('Same as primary', None, [], {}),
+ SamplerData('Same as primary', None, [], {}),
]
@@ -175,14 +180,14 @@ def __init__(self, name, constructor, model, **kwargs):
orig_config = model.default_scheduler.scheduler_config
else:
orig_config = model.default_scheduler.config
- for key, value in config.get(name, {}).items(): # apply diffusers per-scheduler defaults
- self.config[key] = value
debug(f'Sampler: diffusers="{self.config}"')
debug(f'Sampler: original="{orig_config}"')
for key, value in orig_config.items(): # apply model defaults
if key in self.config:
self.config[key] = value
debug(f'Sampler: default="{self.config}"')
+ for key, value in config.get(name, {}).items(): # apply diffusers per-scheduler defaults
+ self.config[key] = value
for key, value in kwargs.items(): # apply user args, if any
if key in self.config:
self.config[key] = value
@@ -200,16 +205,20 @@ def __init__(self, name, constructor, model, **kwargs):
timesteps = re.split(',| ', shared.opts.schedulers_timesteps)
timesteps = [int(x) for x in timesteps if x.isdigit()]
if len(timesteps) == 0:
- if 'use_beta_sigmas' in self.config:
- self.config['use_beta_sigmas'] = shared.opts.schedulers_sigma == 'beta'
- if 'use_karras_sigmas' in self.config:
- self.config['use_karras_sigmas'] = shared.opts.schedulers_sigma == 'karras'
- if 'use_exponential_sigmas' in self.config:
- self.config['use_exponential_sigmas'] = shared.opts.schedulers_sigma == 'exponential'
- if 'use_lu_lambdas' in self.config:
- self.config['use_lu_lambdas'] = shared.opts.schedulers_sigma == 'lambdas'
if 'sigma_schedule' in self.config:
self.config['sigma_schedule'] = shared.opts.schedulers_sigma if shared.opts.schedulers_sigma != 'default' else None
+ if shared.opts.schedulers_sigma == 'default' and shared.sd_model_type in flow_models and 'use_flow_sigmas' in self.config:
+ self.config['use_flow_sigmas'] = True
+ elif shared.opts.schedulers_sigma == 'betas' and 'use_beta_sigmas' in self.config:
+ self.config['use_beta_sigmas'] = True
+ elif shared.opts.schedulers_sigma == 'karras' and 'use_karras_sigmas' in self.config:
+ self.config['use_karras_sigmas'] = True
+ elif shared.opts.schedulers_sigma == 'flowmatch' and 'use_flow_sigmas' in self.config:
+ self.config['use_flow_sigmas'] = True
+ elif shared.opts.schedulers_sigma == 'exponential' and 'use_exponential_sigmas' in self.config:
+ self.config['use_exponential_sigmas'] = True
+ elif shared.opts.schedulers_sigma == 'lambdas' and 'use_lu_lambdas' in self.config:
+ self.config['use_lu_lambdas'] = True
else:
pass # timesteps are set using set_timesteps in set_pipeline_args
@@ -236,7 +245,7 @@ def __init__(self, name, constructor, model, **kwargs):
if 'use_dynamic_shifting' in self.config:
if 'Flux' in model.__class__.__name__:
self.config['use_dynamic_shifting'] = shared.opts.schedulers_dynamic_shift
- if 'use_beta_sigmas' in self.config:
+ if 'use_beta_sigmas' in self.config and 'sigma_schedule' in self.config:
self.config['use_beta_sigmas'] = 'StableDiffusion3' in model.__class__.__name__
if 'rescale_betas_zero_snr' in self.config:
self.config['rescale_betas_zero_snr'] = shared.opts.schedulers_rescale_betas
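
Scheduler construction above now applies model defaults before the per-scheduler defaults and translates the single sigma-schedule setting into whichever `use_*_sigmas` flag the target scheduler config actually exposes, with flow-matching models falling back to `use_flow_sigmas` when the setting is left at `default`. A reduced sketch of that mapping; the helper name is invented for illustration:

```python
# Sketch: map the user-facing sigma option onto scheduler config flags.
def apply_sigma_option(config: dict, option: str, is_flow_model: bool) -> dict:
    flags = {
        'betas': 'use_beta_sigmas',
        'karras': 'use_karras_sigmas',
        'flowmatch': 'use_flow_sigmas',
        'exponential': 'use_exponential_sigmas',
        'lambdas': 'use_lu_lambdas',
    }
    if option == 'default' and is_flow_model and 'use_flow_sigmas' in config:
        config['use_flow_sigmas'] = True      # flow models default to flow sigmas
    elif option in flags and flags[option] in config:
        config[flags[option]] = True          # only set flags the scheduler exposes
    return config
```
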
diff --git a/modules/sd_unet.py b/modules/sd_unet.py
index f730bdb74..deb0b24b0 100644
--- a/modules/sd_unet.py
+++ b/modules/sd_unet.py
@@ -36,7 +36,7 @@ def load_unet(model):
model.prior_pipe.text_encoder = prior_text_encoder.to(devices.device, dtype=devices.dtype)
elif "Flux" in model.__class__.__name__ or "StableDiffusion3" in model.__class__.__name__:
loaded_unet = shared.opts.sd_unet
- sd_models.load_diffuser() # TODO forcing reloading entire model as loading transformers only leads to massive memory usage
+ sd_models.load_diffuser() # TODO model load: force-reloading entire model as loading transformers only leads to massive memory usage
"""
from modules.model_flux import load_transformer
transformer = load_transformer(unet_dict[shared.opts.sd_unet])
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py
index 4d213ad48..a1959817c 100644
--- a/modules/sd_vae_taesd.py
+++ b/modules/sd_vae_taesd.py
@@ -169,6 +169,9 @@ def decode(latents):
if vae is None:
return latents
try:
+ size = max(latents.shape[-1], latents.shape[-2])
+ if size > 256:
+ return latents
with devices.inference_context():
latents = latents.detach().clone().to(devices.device, dtype)
if len(latents.shape) == 3:
diff --git a/modules/segmoe/segmoe_model.py b/modules/segmoe/segmoe_model.py
index 4b96527be..a542c1fbb 100644
--- a/modules/segmoe/segmoe_model.py
+++ b/modules/segmoe/segmoe_model.py
@@ -136,7 +136,7 @@ def __init__(self, config_or_path, **kwargs) -> Any:
memory_format=torch.channels_last,
)
- def to(self, *args, **kwargs): # TODO added no-op to avoid error
+ def to(self, *args, **kwargs):
self.pipe.to(*args, **kwargs)
def load_from_scratch(self, config: str, **kwargs) -> None:
@@ -202,7 +202,6 @@ def load_from_scratch(self, config: str, **kwargs) -> None:
self.config["down_idx_start"] = self.down_idx_start
self.config["down_idx_end"] = self.down_idx_end
- # TODO: Add Support for Scheduler Selection
self.pipe.scheduler = DDPMScheduler.from_config(self.pipe.scheduler.config)
# Load Experts
@@ -242,7 +241,6 @@ def load_from_scratch(self, config: str, **kwargs) -> None:
**kwargs,
)
- # TODO: Add Support for Scheduler Selection
expert.scheduler = DDPMScheduler.from_config(
expert.scheduler.config
)
diff --git a/modules/shared.py b/modules/shared.py
index a89cbbc95..31429e076 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -16,11 +16,11 @@
import orjson
import diffusers
from rich.console import Console
-from modules import errors, devices, shared_items, shared_state, cmd_args, theme, history
+from modules import errors, devices, shared_items, shared_state, cmd_args, theme, history, files_cache
from modules.paths import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # pylint: disable=W0611
from modules.dml import memory_providers, default_memory_provider, directml_do_hijack
from modules.onnx_impl import initialize_onnx, execution_providers
-from modules.memstats import memory_stats
+from modules.memstats import memory_stats, ram_stats # pylint: disable=unused-import
from modules.ui_components import DropdownEditable
import modules.interrogate
import modules.memmon
@@ -132,7 +132,8 @@ def readfile(filename, silent=False, lock=False):
# data = json.loads(data)
t1 = time.time()
if not silent:
- log.debug(f'Read: file="{filename}" json={len(data)} bytes={os.path.getsize(filename)} time={t1-t0:.3f}')
+ fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
+ log.debug(f'Read: file="{filename}" json={len(data)} bytes={os.path.getsize(filename)} time={t1-t0:.3f} fn={fn}')
except FileNotFoundError as err:
log.debug(f'Reading failed: {filename} {err}')
except Exception as err:
@@ -237,6 +238,8 @@ def default(obj):
mem_stat = memory_stats()
gpu_memory = mem_stat['gpu']['total'] if "gpu" in mem_stat else 0
native = backend == Backend.DIFFUSERS
+if not files_cache.do_cache_folders:
+ log.warning('File cache disabled: ')
class OptionInfo:
@@ -363,7 +366,7 @@ def list_samplers():
def temp_disable_extensions():
disable_safe = ['sd-webui-controlnet', 'multidiffusion-upscaler-for-automatic1111', 'a1111-sd-webui-lycoris', 'sd-webui-agent-scheduler', 'clip-interrogator-ext', 'stable-diffusion-webui-rembg', 'sd-extension-chainner', 'stable-diffusion-webui-images-browser']
- disable_diffusers = ['sd-webui-controlnet', 'multidiffusion-upscaler-for-automatic1111', 'a1111-sd-webui-lycoris', 'sd-webui-animatediff']
+ disable_diffusers = ['sd-webui-controlnet', 'multidiffusion-upscaler-for-automatic1111', 'a1111-sd-webui-lycoris', 'sd-webui-animatediff', 'Lora']
disable_themes = ['sd-webui-lobe-theme', 'cozy-nest', 'sdnext-modernui']
disable_original = []
disabled = []
@@ -431,15 +434,15 @@ def get_default_modes():
cmd_opts.lowvram = True
default_offload_mode = "sequential"
log.info(f"Device detect: memory={gpu_memory:.1f} optimization=lowvram")
- elif gpu_memory <= 8:
- cmd_opts.medvram = True
- default_offload_mode = "model"
- log.info(f"Device detect: memory={gpu_memory:.1f} optimization=medvram")
+ # elif gpu_memory <= 8:
+ # cmd_opts.medvram = True
+ # default_offload_mode = "model"
+ # log.info(f"Device detect: memory={gpu_memory:.1f} optimization=medvram")
else:
- default_offload_mode = "none"
- log.info(f"Device detect: memory={gpu_memory:.1f} optimization=none")
+ default_offload_mode = "balanced"
+ log.info(f"Device detect: memory={gpu_memory:.1f} optimization=balanced")
elif cmd_opts.medvram:
- default_offload_mode = "model"
+ default_offload_mode = "balanced"
elif cmd_opts.lowvram:
default_offload_mode = "sequential"
@@ -449,16 +452,16 @@ def get_default_modes():
default_cross_attention = "Scaled-Dot-Product" if native else "Doggettx's"
elif devices.backend == "mps":
default_cross_attention = "Scaled-Dot-Product" if native else "Doggettx's"
- else: # cuda, rocm, ipex, openvino
- default_cross_attention ="Scaled-Dot-Product"
+ else: # cuda, rocm, zluda, ipex, openvino
+ default_cross_attention = "Scaled-Dot-Product"
if devices.backend == "rocm":
default_sdp_options = ['Memory attention', 'Math attention']
elif devices.backend == "zluda":
- default_sdp_options = ['Math attention']
+ default_sdp_options = ['Math attention', 'Dynamic attention']
else:
default_sdp_options = ['Flash attention', 'Memory attention', 'Math attention']
- if (cmd_opts.lowvram or cmd_opts.medvram) and ('Flash attention' not in default_sdp_options):
+ if (cmd_opts.lowvram or cmd_opts.medvram) and ('Flash attention' not in default_sdp_options and 'Dynamic attention' not in default_sdp_options):
default_sdp_options.append('Dynamic attention')
return default_offload_mode, default_cross_attention, default_sdp_options
@@ -466,165 +469,165 @@ def get_default_modes():
startup_offload_mode, startup_cross_attention, startup_sdp_options = get_default_modes()
-options_templates.update(options_section(('sd', "Execution & Models"), {
+options_templates.update(options_section(('sd', "Models & Loading"), {
"sd_backend": OptionInfo(default_backend, "Execution backend", gr.Radio, {"choices": ["diffusers", "original"] }),
+ "diffusers_pipeline": OptionInfo('Autodetect', 'Model pipeline', gr.Dropdown, lambda: {"choices": list(shared_items.get_pipelines()), "visible": native}),
"sd_model_checkpoint": OptionInfo(default_checkpoint, "Base model", DropdownEditable, lambda: {"choices": list_checkpoint_titles()}, refresh=refresh_checkpoints),
"sd_model_refiner": OptionInfo('None', "Refiner model", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_titles()}, refresh=refresh_checkpoints),
- "sd_vae": OptionInfo("Automatic", "VAE model", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
"sd_unet": OptionInfo("None", "UNET model", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list),
- "sd_text_encoder": OptionInfo('None', "Text encoder model", gr.Dropdown, lambda: {"choices": shared_items.sd_te_items()}, refresh=shared_items.refresh_te_list),
- "sd_model_dict": OptionInfo('None', "Use separate base dict", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_titles()}, refresh=refresh_checkpoints),
+ "latent_history": OptionInfo(16, "Latent history size", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
+
+ "offload_sep": OptionInfo("Model Offloading
", "", gr.HTML),
+ "diffusers_move_base": OptionInfo(False, "Move base model to CPU when using refiner", gr.Checkbox, {"visible": False }),
+ "diffusers_move_unet": OptionInfo(False, "Move base model to CPU when using VAE", gr.Checkbox, {"visible": False }),
+ "diffusers_move_refiner": OptionInfo(False, "Move refiner model to CPU when not in use", gr.Checkbox, {"visible": False }),
+ "diffusers_extract_ema": OptionInfo(False, "Use model EMA weights when possible", gr.Checkbox, {"visible": False }),
+ "diffusers_offload_mode": OptionInfo(startup_offload_mode, "Model offload mode", gr.Radio, {"choices": ['none', 'balanced', 'model', 'sequential']}),
+ "diffusers_offload_min_gpu_memory": OptionInfo(0.25, "Balanced offload GPU low watermark", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01 }),
+ "diffusers_offload_max_gpu_memory": OptionInfo(0.70, "Balanced offload GPU high watermark", gr.Slider, {"minimum": 0.1, "maximum": 1, "step": 0.01 }),
+ "diffusers_offload_max_cpu_memory": OptionInfo(0.90, "Balanced offload CPU high watermark", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": False }),
+
+ "advanced_sep": OptionInfo("Advanced Options
", "", gr.HTML),
"sd_checkpoint_autoload": OptionInfo(True, "Model autoload on start"),
"sd_checkpoint_autodownload": OptionInfo(True, "Model auto-download on demand"),
- "sd_textencoder_cache": OptionInfo(True, "Cache text encoder results", gr.Checkbox, {"visible": False}),
- "sd_textencoder_cache_size": OptionInfo(4, "Text encoder results LRU cache size", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
- "stream_load": OptionInfo(False, "Load models using stream loading method", gr.Checkbox, {"visible": not native }),
+ "stream_load": OptionInfo(False, "Model load using streams", gr.Checkbox),
+ "diffusers_eval": OptionInfo(True, "Force model eval", gr.Checkbox, {"visible": False }),
+ "diffusers_to_gpu": OptionInfo(False, "Load model directly to GPU"),
+ "disable_accelerate": OptionInfo(False, "Disable accelerate", gr.Checkbox, {"visible": False }),
+ "sd_model_dict": OptionInfo('None', "Use separate base dict", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_titles()}, refresh=refresh_checkpoints),
+ "sd_checkpoint_cache": OptionInfo(0, "Cached models", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1, "visible": not native }),
+}))
+
+options_templates.update(options_section(('vae_encoder', "Variable Auto Encoder"), {
+ "sd_vae": OptionInfo("Automatic", "VAE model", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
+ "diffusers_vae_upcast": OptionInfo("default", "VAE upcasting", gr.Radio, {"choices": ['default', 'true', 'false']}),
+ "no_half_vae": OptionInfo(False if not cmd_opts.use_openvino else True, "Full precision (--no-half-vae)"),
+ "diffusers_vae_slicing": OptionInfo(True, "VAE slicing", gr.Checkbox, {"visible": native}),
+ "diffusers_vae_tiling": OptionInfo(cmd_opts.lowvram or cmd_opts.medvram, "VAE tiling", gr.Checkbox, {"visible": native}),
+ "sd_vae_sliced_encode": OptionInfo(False, "VAE sliced encode", gr.Checkbox, {"visible": not native}),
+ "nan_skip": OptionInfo(False, "Skip Generation if NaN found in latents", gr.Checkbox),
+ "rollback_vae": OptionInfo(False, "Attempt VAE roll back for NaN values"),
+}))
+
+options_templates.update(options_section(('text_encoder', "Text Encoder"), {
+ "sd_text_encoder": OptionInfo('None', "Text encoder model", gr.Dropdown, lambda: {"choices": shared_items.sd_te_items()}, refresh=shared_items.refresh_te_list),
+ "prompt_attention": OptionInfo("native", "Prompt attention parser", gr.Radio, {"choices": ["native", "compel", "xhinker", "a1111", "fixed"] }),
"prompt_mean_norm": OptionInfo(False, "Prompt attention normalization", gr.Checkbox),
+ "sd_textencoder_cache": OptionInfo(True, "Cache text encoder results", gr.Checkbox, {"visible": False}),
+ "sd_textencoder_cache_size": OptionInfo(4, "Text encoder cache size", gr.Slider, {"minimum": 0, "maximum": 16, "step": 1}),
"comma_padding_backtrack": OptionInfo(20, "Prompt padding", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1, "visible": not native }),
- "prompt_attention": OptionInfo("native", "Prompt attention parser", gr.Radio, {"choices": ["native", "compel", "xhinker", "a1111", "fixed"] }),
- "latent_history": OptionInfo(16, "Latent history size", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
- "sd_checkpoint_cache": OptionInfo(0, "Cached models", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1, "visible": not native }),
+ "diffusers_zeros_prompt_pad": OptionInfo(False, "Use zeros for prompt padding", gr.Checkbox),
+ "diffusers_pooled": OptionInfo("default", "Diffusers SDXL pooled embeds", gr.Radio, {"choices": ['default', 'weighted']}),
}))
options_templates.update(options_section(('cuda', "Compute Settings"), {
- "math_sep": OptionInfo("Execution precision
", "", gr.HTML),
+ "math_sep": OptionInfo("Execution Precision
", "", gr.HTML),
"precision": OptionInfo("Autocast", "Precision type", gr.Radio, {"choices": ["Autocast", "Full"]}),
"cuda_dtype": OptionInfo("Auto", "Device precision type", gr.Radio, {"choices": ["Auto", "FP32", "FP16", "BF16"]}),
+ "no_half": OptionInfo(False if not cmd_opts.use_openvino else True, "Full precision (--no-half)", None, None, None),
+ "upcast_sampling": OptionInfo(False if sys.platform != "darwin" else True, "Upcast sampling", gr.Checkbox, {"visible": not native}),
+ "upcast_attn": OptionInfo(False, "Upcast attention layer", gr.Checkbox, {"visible": not native}),
+ "cuda_cast_unet": OptionInfo(False, "Fixed UNet precision", gr.Checkbox, {"visible": not native}),
- "model_sep": OptionInfo("Model options
", "", gr.HTML),
- "no_half": OptionInfo(False if not cmd_opts.use_openvino else True, "Full precision for model (--no-half)", None, None, None),
- "no_half_vae": OptionInfo(False if not cmd_opts.use_openvino else True, "Full precision for VAE (--no-half-vae)"),
- "upcast_sampling": OptionInfo(False if sys.platform != "darwin" else True, "Upcast sampling"),
- "upcast_attn": OptionInfo(False, "Upcast attention layer"),
- "cuda_cast_unet": OptionInfo(False, "Fixed UNet precision"),
- "nan_skip": OptionInfo(False, "Skip Generation if NaN found in latents", gr.Checkbox),
- "rollback_vae": OptionInfo(False, "Attempt VAE roll back for NaN values"),
+ "generator_sep": OptionInfo("Noise Options
", "", gr.HTML),
+ "diffusers_generator_device": OptionInfo("GPU", "Generator device", gr.Radio, {"choices": ["GPU", "CPU", "Unset"]}),
"cross_attention_sep": OptionInfo("Cross Attention
", "", gr.HTML),
- "cross_attention_optimization": OptionInfo(startup_cross_attention, "Attention optimization method", gr.Radio, lambda: {"choices": shared_items.list_crossattention(native) }),
- "sdp_options": OptionInfo(startup_sdp_options, "SDP options", gr.CheckboxGroup, {"choices": ['Flash attention', 'Memory attention', 'Math attention', 'Dynamic attention', 'Sage attention'] }),
+ "cross_attention_optimization": OptionInfo(startup_cross_attention, "Attention optimization method", gr.Radio, lambda: {"choices": shared_items.list_crossattention(native)}),
+ "sdp_options": OptionInfo(startup_sdp_options, "SDP options", gr.CheckboxGroup, {"choices": ['Flash attention', 'Memory attention', 'Math attention', 'Dynamic attention', 'Sage attention'], "visible": native}),
"xformers_options": OptionInfo(['Flash attention'], "xFormers options", gr.CheckboxGroup, {"choices": ['Flash attention'] }),
"dynamic_attention_slice_rate": OptionInfo(4, "Dynamic Attention slicing rate in GB", gr.Slider, {"minimum": 0.1, "maximum": gpu_memory, "step": 0.1, "visible": native}),
"sub_quad_sep": OptionInfo("Sub-quadratic options
", "", gr.HTML, {"visible": not native}),
"sub_quad_q_chunk_size": OptionInfo(512, "Attention query chunk size", gr.Slider, {"minimum": 16, "maximum": 8192, "step": 8, "visible": not native}),
"sub_quad_kv_chunk_size": OptionInfo(512, "Attention kv chunk size", gr.Slider, {"minimum": 0, "maximum": 8192, "step": 8, "visible": not native}),
"sub_quad_chunk_threshold": OptionInfo(80, "Attention chunking threshold", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1, "visible": not native}),
+}))
- "other_sep": OptionInfo("Execution options
", "", gr.HTML),
- "opt_channelslast": OptionInfo(False, "Use channels last "),
- "cudnn_deterministic": OptionInfo(False, "Use deterministic mode"),
- "cudnn_benchmark": OptionInfo(False, "Full-depth cuDNN benchmark feature"),
+options_templates.update(options_section(('backends', "Backend Settings"), {
+ "other_sep": OptionInfo("Torch Options
", "", gr.HTML),
+ "opt_channelslast": OptionInfo(False, "Channels last "),
+ "cudnn_deterministic": OptionInfo(False, "Deterministic mode"),
+ "cudnn_benchmark": OptionInfo(False, "Full-depth cuDNN benchmark"),
"diffusers_fuse_projections": OptionInfo(False, "Fused projections"),
- "torch_expandable_segments": OptionInfo(False, "Torch expandable segments"),
- "cuda_mem_fraction": OptionInfo(0.0, "Torch memory limit", gr.Slider, {"minimum": 0, "maximum": 2.0, "step": 0.05}),
- "torch_gc_threshold": OptionInfo(80, "Torch memory threshold for GC", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}),
- "torch_malloc": OptionInfo("native", "Torch memory allocator", gr.Radio, {"choices": ['native', 'cudaMallocAsync'] }),
+ "torch_expandable_segments": OptionInfo(False, "Expandable segments"),
+ "cuda_mem_fraction": OptionInfo(0.0, "Memory limit", gr.Slider, {"minimum": 0, "maximum": 2.0, "step": 0.05}),
+ "torch_gc_threshold": OptionInfo(70, "GC threshold", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}),
+ "inference_mode": OptionInfo("no-grad", "Inference mode", gr.Radio, {"choices": ["no-grad", "inference-mode", "none"]}),
+ "torch_malloc": OptionInfo("native", "Memory allocator", gr.Radio, {"choices": ['native', 'cudaMallocAsync'] }),
+
+ "onnx_sep": OptionInfo("ONNX
", "", gr.HTML),
+ "onnx_execution_provider": OptionInfo(execution_providers.get_default_execution_provider().value, 'ONNX Execution Provider', gr.Dropdown, lambda: {"choices": execution_providers.available_execution_providers }),
+ "onnx_cpu_fallback": OptionInfo(True, 'ONNX allow fallback to CPU'),
+ "onnx_cache_converted": OptionInfo(True, 'ONNX cache converted models'),
+ "onnx_unload_base": OptionInfo(False, 'ONNX unload base model when processing refiner'),
- "cuda_compile_sep": OptionInfo("Model Compile
", "", gr.HTML),
- "cuda_compile": OptionInfo([] if not cmd_opts.use_openvino else ["Model", "VAE"], "Compile Model", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "Upscaler"]}),
- "cuda_compile_backend": OptionInfo("none" if not cmd_opts.use_openvino else "openvino_fx", "Model compile backend", gr.Radio, {"choices": ['none', 'inductor', 'cudagraphs', 'aot_ts_nvfuser', 'hidet', 'migraphx', 'ipex', 'onediff', 'stable-fast', 'deep-cache', 'olive-ai', 'openvino_fx']}),
- "cuda_compile_mode": OptionInfo("default", "Model compile mode", gr.Radio, {"choices": ['default', 'reduce-overhead', 'max-autotune', 'max-autotune-no-cudagraphs']}),
- "cuda_compile_fullgraph": OptionInfo(True if not cmd_opts.use_openvino else False, "Model compile fullgraph"),
- "cuda_compile_precompile": OptionInfo(False, "Model compile precompile"),
- "cuda_compile_verbose": OptionInfo(False, "Model compile verbose mode"),
- "cuda_compile_errors": OptionInfo(True, "Model compile suppress errors"),
- "deep_cache_interval": OptionInfo(3, "DeepCache cache interval", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
+ "olive_sep": OptionInfo("Olive
", "", gr.HTML),
+ "olive_float16": OptionInfo(True, 'Olive use FP16 on optimization'),
+ "olive_vae_encoder_float32": OptionInfo(False, 'Olive force FP32 for VAE Encoder'),
+ "olive_static_dims": OptionInfo(True, 'Olive use static dimensions'),
+ "olive_cache_optimized": OptionInfo(True, 'Olive cache optimized models'),
"ipex_sep": OptionInfo("IPEX
", "", gr.HTML, {"visible": devices.backend == "ipex"}),
- "ipex_optimize": OptionInfo([], "IPEX Optimize for Intel GPUs", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "Upscaler"], "visible": devices.backend == "ipex"}),
+ "ipex_optimize": OptionInfo([], "IPEX Optimize", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "Upscaler"], "visible": devices.backend == "ipex"}),
"openvino_sep": OptionInfo("OpenVINO
", "", gr.HTML, {"visible": cmd_opts.use_openvino}),
"openvino_devices": OptionInfo([], "OpenVINO devices to use", gr.CheckboxGroup, {"choices": get_openvino_device_list() if cmd_opts.use_openvino else [], "visible": cmd_opts.use_openvino}), # pylint: disable=E0606
"openvino_accuracy": OptionInfo("performance", "OpenVINO accuracy mode", gr.Radio, {"choices": ['performance', 'accuracy'], "visible": cmd_opts.use_openvino}),
- "openvino_disable_model_caching": OptionInfo(False, "OpenVINO disable model caching", gr.Checkbox, {"visible": cmd_opts.use_openvino}),
+ "openvino_disable_model_caching": OptionInfo(True, "OpenVINO disable model caching", gr.Checkbox, {"visible": cmd_opts.use_openvino}),
"openvino_disable_memory_cleanup": OptionInfo(True, "OpenVINO disable memory cleanup after compile", gr.Checkbox, {"visible": cmd_opts.use_openvino}),
"directml_sep": OptionInfo("DirectML
", "", gr.HTML, {"visible": devices.backend == "directml"}),
"directml_memory_provider": OptionInfo(default_memory_provider, 'DirectML memory stats provider', gr.Radio, {"choices": memory_providers, "visible": devices.backend == "directml"}),
"directml_catch_nan": OptionInfo(False, "DirectML retry ops for NaN", gr.Checkbox, {"visible": devices.backend == "directml"}),
-
- "olive_sep": OptionInfo("Olive
", "", gr.HTML),
- "olive_float16": OptionInfo(True, 'Olive use FP16 on optimization'),
- "olive_vae_encoder_float32": OptionInfo(False, 'Olive force FP32 for VAE Encoder'),
- "olive_static_dims": OptionInfo(True, 'Olive use static dimensions'),
- "olive_cache_optimized": OptionInfo(True, 'Olive cache optimized models'),
-}))
-
-options_templates.update(options_section(('diffusers', "Diffusers Settings"), {
- "diffusers_pipeline": OptionInfo('Autodetect', 'Diffusers pipeline', gr.Dropdown, lambda: {"choices": list(shared_items.get_pipelines()) }),
- "diffuser_cache_config": OptionInfo(True, "Use cached model config when available"),
- "diffusers_move_base": OptionInfo(False, "Move base model to CPU when using refiner"),
- "diffusers_move_unet": OptionInfo(False, "Move base model to CPU when using VAE"),
- "diffusers_move_refiner": OptionInfo(False, "Move refiner model to CPU when not in use"),
- "diffusers_extract_ema": OptionInfo(False, "Use model EMA weights when possible"),
- "diffusers_generator_device": OptionInfo("GPU", "Generator device", gr.Radio, {"choices": ["GPU", "CPU", "Unset"]}),
- "diffusers_offload_mode": OptionInfo(startup_offload_mode, "Model offload mode", gr.Radio, {"choices": ['none', 'balanced', 'model', 'sequential']}),
- "diffusers_offload_max_gpu_memory": OptionInfo(round(gpu_memory * 0.75, 1), "Max GPU memory for balanced offload mode in GB", gr.Slider, {"minimum": 0, "maximum": gpu_memory, "step": 0.01,}),
- "diffusers_offload_max_cpu_memory": OptionInfo(round(cpu_memory * 0.75, 1), "Max CPU memory for balanced offload mode in GB", gr.Slider, {"minimum": 0, "maximum": cpu_memory, "step": 0.01,}),
- "diffusers_vae_upcast": OptionInfo("default", "VAE upcasting", gr.Radio, {"choices": ['default', 'true', 'false']}),
- "diffusers_vae_slicing": OptionInfo(True, "VAE slicing"),
- "diffusers_vae_tiling": OptionInfo(cmd_opts.lowvram or cmd_opts.medvram, "VAE tiling"),
- "diffusers_model_load_variant": OptionInfo("default", "Preferred Model variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}),
- "diffusers_vae_load_variant": OptionInfo("default", "Preferred VAE variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}),
- "custom_diffusers_pipeline": OptionInfo('', 'Load custom Diffusers pipeline'),
- "diffusers_eval": OptionInfo(True, "Force model eval"),
- "diffusers_to_gpu": OptionInfo(False, "Load model directly to GPU"),
- "disable_accelerate": OptionInfo(False, "Disable accelerate"),
- "diffusers_pooled": OptionInfo("default", "Diffusers SDXL pooled embeds", gr.Radio, {"choices": ['default', 'weighted']}),
- "diffusers_zeros_prompt_pad": OptionInfo(False, "Use zeros for prompt padding", gr.Checkbox),
- "huggingface_token": OptionInfo('', 'HuggingFace token'),
- "enable_linfusion": OptionInfo(False, "Apply LinFusion distillation on load"),
-
- "onnx_sep": OptionInfo("ONNX Runtime
", "", gr.HTML),
- "onnx_execution_provider": OptionInfo(execution_providers.get_default_execution_provider().value, 'Execution Provider', gr.Dropdown, lambda: {"choices": execution_providers.available_execution_providers }),
- "onnx_cpu_fallback": OptionInfo(True, 'ONNX allow fallback to CPU'),
- "onnx_cache_converted": OptionInfo(True, 'ONNX cache converted models'),
- "onnx_unload_base": OptionInfo(False, 'ONNX unload base model when processing refiner'),
}))
options_templates.update(options_section(('quantization', "Quantization Settings"), {
- "bnb_quantization": OptionInfo([], "BnB quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": native}),
- "bnb_quantization_type": OptionInfo("nf4", "BnB quantization type", gr.Radio, {"choices": ['nf4', 'fp8', 'fp4'], "visible": native}),
- "bnb_quantization_storage": OptionInfo("uint8", "BnB quantization storage", gr.Radio, {"choices": ["float16", "float32", "int8", "uint8", "float64", "bfloat16"], "visible": native}),
- "optimum_quanto_weights": OptionInfo([], "Optimum.quanto quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "ControlNet"], "visible": native}),
- "optimum_quanto_weights_type": OptionInfo("qint8", "Optimum.quanto quantization type", gr.Radio, {"choices": ['qint8', 'qfloat8_e4m3fn', 'qfloat8_e5m2', 'qint4', 'qint2'], "visible": native}),
- "optimum_quanto_activations_type": OptionInfo("none", "Optimum.quanto quantization activations ", gr.Radio, {"choices": ['none', 'qint8', 'qfloat8_e4m3fn', 'qfloat8_e5m2'], "visible": native}),
- "torchao_quantization": OptionInfo([], "TorchAO quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": native}),
- "torchao_quantization_type": OptionInfo("int8", "TorchAO quantization type", gr.Radio, {"choices": ["int8+act", "int8", "int4", "fp8+act", "fp8", "fpx"], "visible": native}),
- "nncf_compress_weights": OptionInfo([], "NNCF compression enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "ControlNet"], "visible": native}),
- "nncf_compress_weights_mode": OptionInfo("INT8", "NNCF compress mode", gr.Radio, {"choices": ['INT8', 'INT8_SYM', 'INT4_ASYM', 'INT4_SYM', 'NF4'] if cmd_opts.use_openvino else ['INT8']}),
- "nncf_compress_weights_raito": OptionInfo(1.0, "NNCF compress ratio", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": cmd_opts.use_openvino}),
- "nncf_quantize": OptionInfo([], "NNCF OpenVINO quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": cmd_opts.use_openvino}),
- "nncf_quant_mode": OptionInfo("INT8", "NNCF OpenVINO quantization mode", gr.Radio, {"choices": ['INT8', 'FP8_E4M3', 'FP8_E5M2'], "visible": cmd_opts.use_openvino}),
-
- "quant_shuffle_weights": OptionInfo(False, "Shuffle the weights between GPU and CPU when quantizing", gr.Checkbox, {"visible": native}),
+ "bnb_sep": OptionInfo("BitsAndBytes
", "", gr.HTML),
+ "bnb_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": native}),
+ "bnb_quantization_type": OptionInfo("nf4", "Quantization type", gr.Dropdown, {"choices": ['nf4', 'fp8', 'fp4'], "visible": native}),
+ "bnb_quantization_storage": OptionInfo("uint8", "Backend storage", gr.Dropdown, {"choices": ["float16", "float32", "int8", "uint8", "float64", "bfloat16"], "visible": native}),
+ "optimum_quanto_sep": OptionInfo("Optimum Quanto
", "", gr.HTML),
+ "optimum_quanto_weights": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "ControlNet"], "visible": native}),
+ "optimum_quanto_weights_type": OptionInfo("qint8", "Quantization weights type", gr.Dropdown, {"choices": ['qint8', 'qfloat8_e4m3fn', 'qfloat8_e5m2', 'qint4', 'qint2'], "visible": native}),
+ "optimum_quanto_activations_type": OptionInfo("none", "Quantization activations type ", gr.Dropdown, {"choices": ['none', 'qint8', 'qfloat8_e4m3fn', 'qfloat8_e5m2'], "visible": native}),
+ "torchao_sep": OptionInfo("TorchAO
", "", gr.HTML),
+ "torchao_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": native}),
+ "torchao_quantization_mode": OptionInfo("pre", "Quantization mode", gr.Dropdown, {"choices": ['pre', 'post'], "visible": native}),
+ "torchao_quantization_type": OptionInfo("int8_weight_only", "Quantization type", gr.Dropdown, {"choices": ['int4_weight_only', 'int8_dynamic_activation_int4_weight', 'int8_weight_only', 'int8_dynamic_activation_int8_weight', 'float8_weight_only', 'float8_dynamic_activation_float8_weight', 'float8_static_activation_float8_weight'], "visible": native}),
+ "nncf_sep": OptionInfo("NNCF
", "", gr.HTML),
+ "nncf_compress_weights": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "ControlNet"], "visible": native}),
+ "nncf_compress_weights_mode": OptionInfo("INT8", "Quantization type", gr.Dropdown, {"choices": ['INT8', 'INT8_SYM', 'INT4_ASYM', 'INT4_SYM', 'NF4'] if cmd_opts.use_openvino else ['INT8']}),
+ "nncf_compress_weights_raito": OptionInfo(1.0, "Compress ratio", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": cmd_opts.use_openvino}),
+ "nncf_quantize": OptionInfo([], "OpenVINO enabled", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": cmd_opts.use_openvino}),
+ "nncf_quant_mode": OptionInfo("INT8", "OpenVINO mode", gr.Dropdown, {"choices": ['INT8', 'FP8_E4M3', 'FP8_E5M2'], "visible": cmd_opts.use_openvino}),
+ "quant_shuffle_weights": OptionInfo(False, "Shuffle weights", gr.Checkbox, {"visible": native}),
}))
-options_templates.update(options_section(('advanced', "Inference Settings"), {
- "token_merging_sep": OptionInfo("Token merging
", "", gr.HTML),
+options_templates.update(options_section(('advanced', "Pipeline Modifiers"), {
+ "token_merging_sep": OptionInfo("Token Merging
", "", gr.HTML),
"token_merging_method": OptionInfo("None", "Token merging method", gr.Radio, {"choices": ['None', 'ToMe', 'ToDo']}),
"tome_ratio": OptionInfo(0.0, "ToMe token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.05}),
"todo_ratio": OptionInfo(0.0, "ToDo token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.05}),
"freeu_sep": OptionInfo("FreeU
", "", gr.HTML),
"freeu_enabled": OptionInfo(False, "FreeU"),
- "freeu_b1": OptionInfo(1.2, "1st stage backbone factor", gr.Slider, {"minimum": 1.0, "maximum": 2.0, "step": 0.01}),
- "freeu_b2": OptionInfo(1.4, "2nd stage backbone factor", gr.Slider, {"minimum": 1.0, "maximum": 2.0, "step": 0.01}),
- "freeu_s1": OptionInfo(0.9, "1st stage skip factor", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "freeu_s2": OptionInfo(0.2, "2nd stage skip factor", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "freeu_b1": OptionInfo(1.2, "1st stage backbone", gr.Slider, {"minimum": 1.0, "maximum": 2.0, "step": 0.01}),
+ "freeu_b2": OptionInfo(1.4, "2nd stage backbone", gr.Slider, {"minimum": 1.0, "maximum": 2.0, "step": 0.01}),
+ "freeu_s1": OptionInfo(0.9, "1st stage skip", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "freeu_s2": OptionInfo(0.2, "2nd stage skip", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"pag_sep": OptionInfo("Perturbed-Attention Guidance
", "", gr.HTML),
"pag_apply_layers": OptionInfo("m0", "PAG layer names"),
"hypertile_sep": OptionInfo("HyperTile
", "", gr.HTML),
- "hypertile_hires_only": OptionInfo(False, "HyperTile hires pass only"),
- "hypertile_unet_enabled": OptionInfo(False, "HyperTile UNet"),
- "hypertile_unet_tile": OptionInfo(0, "HyperTile UNet tile size", gr.Slider, {"minimum": 0, "maximum": 1024, "step": 8}),
- "hypertile_unet_swap_size": OptionInfo(1, "HyperTile UNet swap size", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
- "hypertile_unet_depth": OptionInfo(0, "HyperTile UNet depth", gr.Slider, {"minimum": 0, "maximum": 4, "step": 1}),
- "hypertile_vae_enabled": OptionInfo(False, "HyperTile VAE", gr.Checkbox),
- "hypertile_vae_tile": OptionInfo(128, "HyperTile VAE tile size", gr.Slider, {"minimum": 0, "maximum": 1024, "step": 8}),
- "hypertile_vae_swap_size": OptionInfo(1, "HyperTile VAE swap size", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
+ "hypertile_hires_only": OptionInfo(False, "HiRes pass only"),
+ "hypertile_unet_enabled": OptionInfo(False, "UNet Enabled"),
+ "hypertile_unet_tile": OptionInfo(0, "UNet tile size", gr.Slider, {"minimum": 0, "maximum": 1024, "step": 8}),
+ "hypertile_unet_swap_size": OptionInfo(1, "UNet swap size", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
+ "hypertile_unet_depth": OptionInfo(0, "UNet depth", gr.Slider, {"minimum": 0, "maximum": 4, "step": 1}),
+ "hypertile_vae_enabled": OptionInfo(False, "VAE Enabled", gr.Checkbox),
+ "hypertile_vae_tile": OptionInfo(128, "VAE tile size", gr.Slider, {"minimum": 0, "maximum": 1024, "step": 8}),
+ "hypertile_vae_swap_size": OptionInfo(1, "VAE swap size", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
"hidiffusion_sep": OptionInfo("HiDiffusion
", "", gr.HTML),
"hidiffusion_raunet": OptionInfo(True, "Apply RAU-Net"),
@@ -633,16 +636,28 @@ def get_default_modes():
"hidiffusion_t1": OptionInfo(-1, "Override T1 ratio", gr.Slider, {"minimum": -1, "maximum": 1.0, "step": 0.05}),
"hidiffusion_t2": OptionInfo(-1, "Override T2 ratio", gr.Slider, {"minimum": -1, "maximum": 1.0, "step": 0.05}),
+ "linfusion_sep": OptionInfo("Batch
", "", gr.HTML),
+ "enable_linfusion": OptionInfo(False, "Apply LinFusion distillation on load"),
+
"inference_batch_sep": OptionInfo("Batch
", "", gr.HTML),
"sequential_seed": OptionInfo(True, "Batch mode uses sequential seeds"),
"batch_frame_mode": OptionInfo(False, "Parallel process images in batch"),
- "inference_other_sep": OptionInfo("Other
", "", gr.HTML),
- "inference_mode": OptionInfo("no-grad", "Torch inference mode", gr.Radio, {"choices": ["no-grad", "inference-mode", "none"]}),
- "sd_vae_sliced_encode": OptionInfo(False, "VAE sliced encode", gr.Checkbox, {"visible": not native}),
+}))
+
+options_templates.update(options_section(('compile', "Model Compile"), {
+ "cuda_compile_sep": OptionInfo("Model Compile
", "", gr.HTML),
+ "cuda_compile": OptionInfo([] if not cmd_opts.use_openvino else ["Model", "VAE"], "Compile Model", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder", "Upscaler"]}),
+ "cuda_compile_backend": OptionInfo("none" if not cmd_opts.use_openvino else "openvino_fx", "Model compile backend", gr.Radio, {"choices": ['none', 'inductor', 'cudagraphs', 'aot_ts_nvfuser', 'hidet', 'migraphx', 'ipex', 'onediff', 'stable-fast', 'deep-cache', 'olive-ai', 'openvino_fx']}),
+ "cuda_compile_mode": OptionInfo("default", "Model compile mode", gr.Radio, {"choices": ['default', 'reduce-overhead', 'max-autotune', 'max-autotune-no-cudagraphs']}),
+ "cuda_compile_fullgraph": OptionInfo(True if not cmd_opts.use_openvino else False, "Model compile fullgraph"),
+ "cuda_compile_precompile": OptionInfo(False, "Model compile precompile"),
+ "cuda_compile_verbose": OptionInfo(False, "Model compile verbose mode"),
+ "cuda_compile_errors": OptionInfo(True, "Model compile suppress errors"),
+ "deep_cache_interval": OptionInfo(3, "DeepCache cache interval", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
}))
options_templates.update(options_section(('system-paths', "System Paths"), {
- "models_paths_sep_options": OptionInfo("Models paths
", "", gr.HTML),
+ "models_paths_sep_options": OptionInfo("Models Paths
", "", gr.HTML),
"models_dir": OptionInfo('models', "Base path where all models are stored", folder=True),
"ckpt_dir": OptionInfo(os.path.join(paths.models_path, 'Stable-diffusion'), "Folder with stable diffusion models", folder=True),
"diffusers_dir": OptionInfo(os.path.join(paths.models_path, 'Diffusers'), "Folder with Huggingface models", folder=True),
@@ -723,13 +738,13 @@ def get_default_modes():
}))
options_templates.update(options_section(('saving-paths', "Image Paths"), {
- "saving_sep_images": OptionInfo("Save options
", "", gr.HTML),
+ "saving_sep_images": OptionInfo("Save Options
", "", gr.HTML),
"save_images_add_number": OptionInfo(True, "Numbered filenames", component_args=hide_dirs),
"use_original_name_batch": OptionInfo(True, "Batch uses original name"),
"save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs),
"samples_filename_pattern": OptionInfo("[seq]-[model_name]-[prompt_words]", "Images filename pattern", component_args=hide_dirs),
- "directories_max_prompt_words": OptionInfo(8, "Max words per pattern", gr.Slider, {"minimum": 1, "maximum": 99, "step": 1, **hide_dirs}),
+ "directories_max_prompt_words": OptionInfo(8, "Max words", gr.Slider, {"minimum": 1, "maximum": 99, "step": 1, **hide_dirs}),
"outdir_sep_dirs": OptionInfo("Folders
", "", gr.HTML),
"outdir_samples": OptionInfo("", "Images folder", component_args=hide_dirs, folder=True),
@@ -748,14 +763,14 @@ def get_default_modes():
"outdir_control_grids": OptionInfo("outputs/grids", 'Folder for control grids', component_args=hide_dirs, folder=True),
}))
-options_templates.update(options_section(('ui', "User Interface Options"), {
+options_templates.update(options_section(('ui', "User Interface"), {
"theme_type": OptionInfo("Standard", "Theme type", gr.Radio, {"choices": ["Modern", "Standard", "None"]}),
"theme_style": OptionInfo("Auto", "Theme mode", gr.Radio, {"choices": ["Auto", "Dark", "Light"]}),
"gradio_theme": OptionInfo("black-teal", "UI theme", gr.Dropdown, lambda: {"choices": theme.list_themes()}, refresh=theme.refresh_themes),
"autolaunch": OptionInfo(False, "Autolaunch browser upon startup"),
"font_size": OptionInfo(14, "Font size", gr.Slider, {"minimum": 8, "maximum": 32, "step": 1, "visible": True}),
"aspect_ratios": OptionInfo("1:1, 4:3, 3:2, 16:9, 16:10, 21:9, 2:3, 3:4, 9:16, 10:16, 9:21", "Allowed aspect ratios"),
- "motd": OptionInfo(True, "Show MOTD"),
+ "motd": OptionInfo(False, "Show MOTD"),
"compact_view": OptionInfo(False, "Compact view"),
"return_grid": OptionInfo(True, "Show grid in results"),
"return_mask": OptionInfo(False, "Inpainting include greyscale mask in results"),
@@ -767,14 +782,15 @@ def get_default_modes():
}))
options_templates.update(options_section(('live-preview', "Live Previews"), {
- "notification_audio_enable": OptionInfo(False, "Play a notification upon completion"),
- "notification_audio_path": OptionInfo("html/notification.mp3","Path to notification sound", component_args=hide_dirs, folder=True),
"show_progress_every_n_steps": OptionInfo(1, "Live preview display period", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
"show_progress_type": OptionInfo("Approximate", "Live preview method", gr.Radio, {"choices": ["Simple", "Approximate", "TAESD", "Full VAE"]}),
"live_preview_refresh_period": OptionInfo(500, "Progress update period", gr.Slider, {"minimum": 0, "maximum": 5000, "step": 25}),
"live_preview_taesd_layers": OptionInfo(3, "TAESD decode layers", gr.Slider, {"minimum": 1, "maximum": 3, "step": 1}),
+ "live_preview_downscale": OptionInfo(True, "Downscale high resolution live previews"),
"logmonitor_show": OptionInfo(True, "Show log view"),
"logmonitor_refresh_period": OptionInfo(5000, "Log view update period", gr.Slider, {"minimum": 0, "maximum": 30000, "step": 25}),
+ "notification_audio_enable": OptionInfo(False, "Play a notification upon completion"),
+ "notification_audio_path": OptionInfo("html/notification.mp3","Path to notification sound", component_args=hide_dirs, folder=True),
}))
options_templates.update(options_section(('sampler-params', "Sampler Settings"), {
@@ -813,7 +829,7 @@ def get_default_modes():
's_noise': OptionInfo(1.0, "Sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01, "visible": not native}),
's_min': OptionInfo(0.0, "Sigma min", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01, "visible": not native}),
's_max': OptionInfo(0.0, "Sigma max", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 1.0, "visible": not native}),
- "schedulers_sep_compvis": OptionInfo("CompVis specific config
", "", gr.HTML, {"visible": not native}),
+ "schedulers_sep_compvis": OptionInfo("CompVis Config
", "", gr.HTML, {"visible": not native}),
'uni_pc_variant': OptionInfo("bh2", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"], "visible": not native}),
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"], "visible": not native}),
"ddim_discretize": OptionInfo('uniform', "DDIM discretize img2img", gr.Radio, {"choices": ['uniform', 'quad'], "visible": not native}),
@@ -821,9 +837,9 @@ def get_default_modes():
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
'postprocessing_enable_in_main_ui': OptionInfo([], "Additional postprocessing operations", gr.Dropdown, lambda: {"multiselect":True, "choices": [x.name for x in shared_items.postprocessing_scripts()]}),
- 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", gr.Dropdown, lambda: {"multiselect":True, "choices": [x.name for x in shared_items.postprocessing_scripts()]}),
+ 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", gr.Dropdown, lambda: {"multiselect":True, "choices": [x.name for x in shared_items.postprocessing_scripts()], "visible": False }),
- "postprocessing_sep_img2img": OptionInfo("Img2Img & Inpainting
", "", gr.HTML),
+ "postprocessing_sep_img2img": OptionInfo("Inpaint
", "", gr.HTML),
"img2img_color_correction": OptionInfo(False, "Apply color correction"),
"mask_apply_overlay": OptionInfo(True, "Apply mask as overlay"),
"img2img_background_color": OptionInfo("#ffffff", "Image transparent color fill", gr.ColorPicker, {}),
@@ -831,7 +847,7 @@ def get_default_modes():
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for image processing", gr.Slider, {"minimum": 0.1, "maximum": 1.5, "step": 0.01, "visible": not native}),
"img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01, "visible": not native}),
- # "postprocessing_sep_detailer": OptionInfo("Detailer
", "", gr.HTML),
+ "postprocessing_sep_detailer": OptionInfo("Detailer
", "", gr.HTML),
"detailer_model": OptionInfo("Detailer", "Detailer model", gr.Radio, lambda: {"choices": [x.name() for x in detailers], "visible": False}),
"detailer_classes": OptionInfo("", "Detailer classes", gr.Textbox, { "visible": False}),
"detailer_conf": OptionInfo(0.6, "Min confidence", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.05, "visible": False}),
@@ -843,11 +859,12 @@ def get_default_modes():
"detailer_blur": OptionInfo(10, "Item edge blur", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1, "visible": False}),
"detailer_strength": OptionInfo(0.5, "Detailer strength", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": False}),
"detailer_models": OptionInfo(['face-yolo8n'], "Detailer models", gr.Dropdown, lambda: {"multiselect":True, "choices": list(yolo.list), "visible": False}),
- "code_former_weight": OptionInfo(0.2, "CodeFormer weight parameter", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": False}),
"detailer_unload": OptionInfo(False, "Move detailer model to CPU when complete"),
+ "detailer_augment": OptionInfo(True, "Detailer use model augment"),
- "postprocessing_sep_face_restore": OptionInfo("Face restore
", "", gr.HTML),
- "face_restoration_model": OptionInfo("Face restorer", "Face restoration", gr.Radio, lambda: {"choices": ['None'] + [x.name() for x in face_restorers]}),
+ "postprocessing_sep_face_restore": OptionInfo("Face Restore
", "", gr.HTML),
+ "face_restoration_model": OptionInfo("None", "Face restoration", gr.Radio, lambda: {"choices": ['None'] + [x.name() for x in face_restorers]}),
+ "code_former_weight": OptionInfo(0.2, "CodeFormer weight parameter", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"postprocessing_sep_upscalers": OptionInfo("Upscaling
", "", gr.HTML),
"upscaler_unload": OptionInfo(False, "Unload upscaler after processing"),
@@ -857,6 +874,7 @@ def get_default_modes():
options_templates.update(options_section(('control', "Control Options"), {
"control_max_units": OptionInfo(4, "Maximum number of units", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
+ "control_tiles": OptionInfo("1x1, 1x2, 1x3, 1x4, 2x1, 2x1, 2x2, 2x3, 2x4, 3x1, 3x2, 3x3, 3x4, 4x1, 4x2, 4x3, 4x4", "Tiling options"),
"control_move_processor": OptionInfo(False, "Processor move to CPU after use"),
"control_unload_processor": OptionInfo(False, "Processor unload after use"),
}))
@@ -875,6 +893,15 @@ def get_default_modes():
"deepbooru_filter_tags": OptionInfo("", "Filter out tags from deepbooru output"),
}))
+options_templates.update(options_section(('huggingface', "Huggingface"), {
+ "huggingface_sep": OptionInfo("Huggingface
", "", gr.HTML),
+ "diffuser_cache_config": OptionInfo(True, "Use cached model config when available"),
+ "huggingface_token": OptionInfo('', 'HuggingFace token'),
+ "diffusers_model_load_variant": OptionInfo("default", "Preferred Model variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}),
+ "diffusers_vae_load_variant": OptionInfo("default", "Preferred VAE variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}),
+ "custom_diffusers_pipeline": OptionInfo('', 'Load custom Diffusers pipeline'),
+}))
+
options_templates.update(options_section(('extra_networks', "Networks"), {
"extra_networks_sep1": OptionInfo("Networks UI
", "", gr.HTML),
"extra_networks_show": OptionInfo(True, "UI show on startup"),
@@ -891,23 +918,27 @@ def get_default_modes():
"extra_networks_model_sep": OptionInfo("Models
", "", gr.HTML),
"extra_network_reference": OptionInfo(False, "Use reference values when available", gr.Checkbox),
- "extra_networks_embed_sep": OptionInfo("Embeddings
", "", gr.HTML),
- "diffusers_convert_embed": OptionInfo(False, "Auto-convert SD 1.5 embeddings to SDXL ", gr.Checkbox, {"visible": native}),
- "extra_networks_styles_sep": OptionInfo("Styles
", "", gr.HTML),
- "extra_networks_styles": OptionInfo(True, "Show built-in styles"),
- "extra_networks_wildcard_sep": OptionInfo("Wildcards
", "", gr.HTML),
- "wildcards_enabled": OptionInfo(True, "Enable file wildcards support"),
+
"extra_networks_lora_sep": OptionInfo("LoRA
", "", gr.HTML),
"extra_networks_default_multiplier": OptionInfo(1.0, "Default strength", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
- "lora_preferred_name": OptionInfo("filename", "LoRA preferred name", gr.Radio, {"choices": ["filename", "alias"]}),
- "lora_add_hashes_to_infotext": OptionInfo(False, "LoRA add hash info"),
+ "lora_preferred_name": OptionInfo("filename", "LoRA preferred name", gr.Radio, {"choices": ["filename", "alias"], "visible": False}),
+ "lora_add_hashes_to_infotext": OptionInfo(False, "LoRA add hash info to metadata"),
+ "lora_fuse_diffusers": OptionInfo(True, "LoRA fuse directly to model"),
"lora_force_diffusers": OptionInfo(False if not cmd_opts.use_openvino else True, "LoRA force loading of all models using Diffusers"),
"lora_maybe_diffusers": OptionInfo(False, "LoRA force loading of specific models using Diffusers"),
- "lora_fuse_diffusers": OptionInfo(False if not cmd_opts.use_openvino else True, "LoRA use fuse when possible"),
"lora_apply_tags": OptionInfo(0, "LoRA auto-apply tags", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
"lora_in_memory_limit": OptionInfo(0, "LoRA memory cache", gr.Slider, {"minimum": 0, "maximum": 24, "step": 1}),
- "lora_quant": OptionInfo("NF4","LoRA precision in quantized models", gr.Radio, {"choices": ["NF4", "FP4"]}),
- "lora_load_gpu": OptionInfo(True if not cmd_opts.lowvram else False, "Load LoRA directly to GPU"),
+ "lora_quant": OptionInfo("NF4","LoRA precision when quantized", gr.Radio, {"choices": ["NF4", "FP4"]}),
+
+ "extra_networks_styles_sep": OptionInfo("Styles
", "", gr.HTML),
+ "extra_networks_styles": OptionInfo(True, "Show built-in styles"),
+
+ "extra_networks_embed_sep": OptionInfo("Embeddings
", "", gr.HTML),
+ "diffusers_enable_embed": OptionInfo(True, "Enable embeddings support", gr.Checkbox, {"visible": native}),
+ "diffusers_convert_embed": OptionInfo(False, "Auto-convert SD15 embeddings to SDXL", gr.Checkbox, {"visible": native}),
+
+ "extra_networks_wildcard_sep": OptionInfo("Wildcards
", "", gr.HTML),
+ "wildcards_enabled": OptionInfo(True, "Enable file wildcards support"),
}))
options_templates.update(options_section((None, "Internal options"), {
diff --git a/modules/shared_items.py b/modules/shared_items.py
index 4d9b325c8..9abb64718 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -86,6 +86,14 @@ def get_pipelines():
'Kolors': getattr(diffusers, 'KolorsPipeline', None),
'AuraFlow': getattr(diffusers, 'AuraFlowPipeline', None),
'CogView': getattr(diffusers, 'CogView3PlusPipeline', None),
+ 'Stable Cascade': getattr(diffusers, 'StableCascadeCombinedPipeline', None),
+ 'PixArt-Sigma': getattr(diffusers, 'PixArtSigmaPipeline', None),
+ 'HunyuanDiT': getattr(diffusers, 'HunyuanDiTPipeline', None),
+ 'Stable Diffusion 3': getattr(diffusers, 'StableDiffusion3Pipeline', None),
+ 'Stable Diffusion 3 Img2Img': getattr(diffusers, 'StableDiffusion3Img2ImgPipeline', None),
+ 'Lumina-Next': getattr(diffusers, 'LuminaText2ImgPipeline', None),
+ 'FLUX': getattr(diffusers, 'FluxPipeline', None),
+ 'Sana': getattr(diffusers, 'SanaPAGPipeline', None),
}
if hasattr(diffusers, 'OnnxStableDiffusionPipeline'):
onnx_pipelines = {
@@ -103,19 +111,10 @@ def get_pipelines():
pipelines.update(onnx_pipelines)
# items that may rely on diffusers dev version
- if hasattr(diffusers, 'StableCascadeCombinedPipeline'):
- pipelines['Stable Cascade'] = getattr(diffusers, 'StableCascadeCombinedPipeline', None)
- if hasattr(diffusers, 'PixArtSigmaPipeline'):
- pipelines['PixArt-Sigma'] = getattr(diffusers, 'PixArtSigmaPipeline', None)
- if hasattr(diffusers, 'HunyuanDiTPipeline'):
- pipelines['HunyuanDiT'] = getattr(diffusers, 'HunyuanDiTPipeline', None)
- if hasattr(diffusers, 'StableDiffusion3Pipeline'):
- pipelines['Stable Diffusion 3'] = getattr(diffusers, 'StableDiffusion3Pipeline', None)
- pipelines['Stable Diffusion 3 Img2Img'] = getattr(diffusers, 'StableDiffusion3Img2ImgPipeline', None)
- if hasattr(diffusers, 'LuminaText2ImgPipeline'):
- pipelines['Lumina-Next'] = getattr(diffusers, 'LuminaText2ImgPipeline', None)
+ """
if hasattr(diffusers, 'FluxPipeline'):
pipelines['FLUX'] = getattr(diffusers, 'FluxPipeline', None)
+ """
for k, v in pipelines.items():
if k != 'Autodetect' and v is None:
diff --git a/modules/shared_state.py b/modules/shared_state.py
index 7def42b8c..a3312ec33 100644
--- a/modules/shared_state.py
+++ b/modules/shared_state.py
@@ -28,6 +28,9 @@ class State:
oom = False
debug_output = os.environ.get('SD_STATE_DEBUG', None)
+ def __str__(self) -> str:
+ return f'State: job={self.job} {self.job_no}/{self.job_count} step={self.sampling_step}/{self.sampling_steps} skipped={self.skipped} interrupted={self.interrupted} paused={self.paused} info={self.textinfo}'
+
def skip(self):
log.debug('Requested skip')
self.skipped = True
@@ -135,11 +138,12 @@ def end(self, api=None):
modules.devices.torch_gc()
def set_current_image(self):
+ if self.job == 'VAE': # avoid generating preview while vae is running
+ return
from modules.shared import opts, cmd_opts
- """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
- if cmd_opts.lowvram or self.api:
+ if cmd_opts.lowvram or self.api or not opts.live_previews_enable or opts.show_progress_every_n_steps <= 0:
return
- if abs(self.sampling_step - self.current_image_sampling_step) >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps > 0:
+ if abs(self.sampling_step - self.current_image_sampling_step) >= opts.show_progress_every_n_steps:
self.do_set_current_image()
def do_set_current_image(self):
diff --git a/modules/style_aligned/inversion.py b/modules/style_aligned/inversion.py
new file mode 100644
index 000000000..8c91cc02a
--- /dev/null
+++ b/modules/style_aligned/inversion.py
@@ -0,0 +1,124 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from __future__ import annotations
+from typing import Callable, TYPE_CHECKING
+from diffusers import StableDiffusionXLPipeline
+import torch
+from tqdm import tqdm
+if TYPE_CHECKING:
+ import numpy as np
+
+
+T = torch.Tensor
+TN = T
+InversionCallback = Callable[[StableDiffusionXLPipeline, int, T, dict[str, T]], dict[str, T]]
+
+
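+# Helper routines for DDIM inversion of a reference image with an SDXL pipeline:
+# the image is encoded to a latent and progressively re-noised so that its full
+# trajectory can later be re-injected during style-aligned generation.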
+def _get_text_embeddings(prompt: str, tokenizer, text_encoder, device):
+ # Tokenize text and get embeddings
+ text_inputs = tokenizer(prompt, padding='max_length', max_length=tokenizer.model_max_length, truncation=True, return_tensors='pt')
+ text_input_ids = text_inputs.input_ids
+
+ with torch.no_grad():
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ if prompt == '':
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ return negative_prompt_embeds, negative_pooled_prompt_embeds
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def _encode_text_sdxl(model: StableDiffusionXLPipeline, prompt: str) -> tuple[dict[str, T], T]:
+ device = model._execution_device # pylint: disable=protected-access
+ prompt_embeds, pooled_prompt_embeds, = _get_text_embeddings(prompt, model.tokenizer, model.text_encoder, device) # pylint: disable=unused-variable
+ prompt_embeds_2, pooled_prompt_embeds2, = _get_text_embeddings( prompt, model.tokenizer_2, model.text_encoder_2, device)
+ prompt_embeds = torch.cat((prompt_embeds, prompt_embeds_2), dim=-1)
+ text_encoder_projection_dim = model.text_encoder_2.config.projection_dim
+ add_time_ids = model._get_add_time_ids((1024, 1024), (0, 0), (1024, 1024), model.text_encoder.dtype, # pylint: disable=protected-access
+ text_encoder_projection_dim).to(device)
+ added_cond_kwargs = {"text_embeds": pooled_prompt_embeds2, "time_ids": add_time_ids}
+ return added_cond_kwargs, prompt_embeds
+
+
+def _encode_text_sdxl_with_negative(model: StableDiffusionXLPipeline, prompt: str) -> tuple[dict[str, T], T]:
+ added_cond_kwargs, prompt_embeds = _encode_text_sdxl(model, prompt)
+ added_cond_kwargs_uncond, prompt_embeds_uncond = _encode_text_sdxl(model, "")
+ prompt_embeds = torch.cat((prompt_embeds_uncond, prompt_embeds, ))
+ added_cond_kwargs = {"text_embeds": torch.cat((added_cond_kwargs_uncond["text_embeds"], added_cond_kwargs["text_embeds"])),
+ "time_ids": torch.cat((added_cond_kwargs_uncond["time_ids"], added_cond_kwargs["time_ids"])),}
+ return added_cond_kwargs, prompt_embeds
+
+
+def _encode_image(model: StableDiffusionXLPipeline, image: np.ndarray) -> T:
+ image = torch.from_numpy(image).float() / 255.
+ image = (image * 2 - 1).permute(2, 0, 1).unsqueeze(0)
+ latent = model.vae.encode(image.to(model.vae.device, model.vae.dtype))['latent_dist'].mean * model.vae.config.scaling_factor
+ return latent
+
+
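+# Single reverse-DDIM (inversion) step: recover the predicted x0 from the current
+# latent and noise estimate, then re-noise it to the next (higher) timestep.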
+def _next_step(model: StableDiffusionXLPipeline, model_output: T, timestep: int, sample: T) -> T:
+ timestep, next_timestep = min(timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps, 999), timestep
+ alpha_prod_t = model.scheduler.alphas_cumprod[int(timestep)] if timestep >= 0 else model.scheduler.final_alpha_cumprod
+ alpha_prod_t_next = model.scheduler.alphas_cumprod[int(next_timestep)]
+ beta_prod_t = 1 - alpha_prod_t
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
+ return next_sample
+
+
+def _get_noise_pred(model: StableDiffusionXLPipeline, latent: T, t: T, context: T, guidance_scale: float, added_cond_kwargs: dict[str, T]):
+ latents_input = torch.cat([latent] * 2)
+ noise_pred = model.unet(latents_input, t, encoder_hidden_states=context, added_cond_kwargs=added_cond_kwargs)["sample"]
+ noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+ # latents = next_step(model, noise_pred, t, latent)
+ return noise_pred
+
+
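+# Walk the scheduler timesteps from least to most noise, applying _next_step at each
+# step and collecting every intermediate latent; the returned trajectory is flipped
+# so it runs from fully noised to clean.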
+def _ddim_loop(model: StableDiffusionXLPipeline, z0, prompt, guidance_scale) -> T:
+ all_latent = [z0]
+ added_cond_kwargs, text_embedding = _encode_text_sdxl_with_negative(model, prompt)
+ latent = z0.clone().detach().to(model.text_encoder.dtype)
+ for i in tqdm(range(model.scheduler.num_inference_steps)):
+ t = model.scheduler.timesteps[len(model.scheduler.timesteps) - i - 1]
+ noise_pred = _get_noise_pred(model, latent, t, text_embedding, guidance_scale, added_cond_kwargs)
+ latent = _next_step(model, noise_pred, t, latent)
+ all_latent.append(latent)
+ return torch.cat(all_latent).flip(0)
+
+
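+# Returns the starting latent plus a callback that pins latents[0] to the stored
+# inversion trajectory on every denoising step, keeping the reference image fixed
+# while the rest of the batch is generated.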
+def make_inversion_callback(zts, offset: int = 0):
+
+ def callback_on_step_end(pipeline: StableDiffusionXLPipeline, i: int, t: T, callback_kwargs: dict[str, T]) -> dict[str, T]: # pylint: disable=unused-argument
+ latents = callback_kwargs['latents']
+ latents[0] = zts[max(offset + 1, i + 1)].to(latents.device, latents.dtype)
+ return {'latents': latents}
+ return zts[offset], callback_on_step_end
+
+
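+# Entry point: encode the image, set the scheduler timesteps and return the full
+# latent trajectory produced by the DDIM inversion loop.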
+@torch.no_grad()
+def ddim_inversion(model: StableDiffusionXLPipeline, x0: np.ndarray, prompt: str, num_inference_steps: int, guidance_scale,) -> T:
+ z0 = _encode_image(model, x0)
+ model.scheduler.set_timesteps(num_inference_steps, device=z0.device)
+ zs = _ddim_loop(model, z0, prompt, guidance_scale)
+ return zs
diff --git a/modules/style_aligned/sa_handler.py b/modules/style_aligned/sa_handler.py
new file mode 100644
index 000000000..ee4b1ca79
--- /dev/null
+++ b/modules/style_aligned/sa_handler.py
@@ -0,0 +1,281 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from __future__ import annotations
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from diffusers import StableDiffusionXLPipeline
+from dataclasses import dataclass
+import torch
+import torch.nn as nn
+from torch.nn import functional as nnf
+from diffusers.models import attention_processor # pylint: disable=ungrouped-imports
+import einops
+
+T = torch.Tensor
+
+
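+# Configuration for style-aligned sharing: which norms and attention components are
+# shared across the batch and how strongly the shared scores are scaled or shifted.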
+@dataclass(frozen=True)
+class StyleAlignedArgs:
+ share_group_norm: bool = True
+ share_layer_norm: bool = True
+ share_attention: bool = True
+ adain_queries: bool = True
+ adain_keys: bool = True
+ adain_values: bool = False
+ full_attention_share: bool = False
+ shared_score_scale: float = 1.
+ shared_score_shift: float = 0.
+ only_self_level: float = 0.
+
+
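+# Take the first (reference) item of each half-batch and expand it across that half;
+# copies other than the first are multiplied by scale.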
+def expand_first(feat: T, scale=1.,) -> T:
+ b = feat.shape[0]
+ feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
+ if scale == 1:
+ feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
+ else:
+ feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
+ feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
+ return feat_style.reshape(*feat.shape)
+
+
+def concat_first(feat: T, dim=2, scale=1.) -> T:
+ feat_style = expand_first(feat, scale=scale)
+ return torch.cat((feat, feat_style), dim=dim)
+
+
+def calc_mean_std(feat, eps: float = 1e-5) -> tuple[T, T]:
+ feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()
+ feat_mean = feat.mean(dim=-2, keepdims=True)
+ return feat_mean, feat_std
+
+
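+# AdaIN: normalize each item's features and re-scale them with the mean/std of the
+# reference (first) item of its half-batch.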
+def adain(feat: T) -> T:
+ feat_mean, feat_std = calc_mean_std(feat)
+ feat_style_mean = expand_first(feat_mean)
+ feat_style_std = expand_first(feat_std)
+ feat = (feat - feat_mean) / feat_std
+ feat = feat * feat_style_std + feat_style_mean
+ return feat
+
+
+class DefaultAttentionProcessor(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.processor = attention_processor.AttnProcessor2_0()
+
+ def __call__(self, attn: attention_processor.Attention, hidden_states, encoder_hidden_states=None,
+ attention_mask=None, **kwargs):
+ return self.processor(attn, hidden_states, encoder_hidden_states, attention_mask)
+
+
+class SharedAttentionProcessor(DefaultAttentionProcessor):
+
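+ # Manual attention used when shared_score_shift != 0: logits attending to the
+ # appended reference keys (indices past the query length) are shifted before softmax.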
+ def shifted_scaled_dot_product_attention(self, attn: attention_processor.Attention, query: T, key: T, value: T) -> T:
+ logits = torch.einsum('bhqd,bhkd->bhqk', query, key) * attn.scale
+ logits[:, :, :, query.shape[2]:] += self.shared_score_shift
+ probs = logits.softmax(-1)
+ return torch.einsum('bhqk,bhkd->bhqd', probs, value)
+
+ def shared_call( # pylint: disable=unused-argument
+ self,
+ attn: attention_processor.Attention,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ **kwargs
+ ):
+
+ residual = hidden_states
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ # if self.step >= self.start_inject:
+ if self.adain_queries:
+ query = adain(query)
+ if self.adain_keys:
+ key = adain(key)
+ if self.adain_values:
+ value = adain(value)
+ if self.share_attention:
+ key = concat_first(key, -2, scale=self.shared_score_scale)
+ value = concat_first(value, -2)
+ if self.shared_score_shift != 0:
+ hidden_states = self.shifted_scaled_dot_product_attention(attn, query, key, value,)
+ else:
+ hidden_states = nnf.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ else:
+ hidden_states = nnf.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ # hidden_states = adain(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+ return hidden_states
+
+ def __call__(self, attn: attention_processor.Attention, hidden_states, encoder_hidden_states=None,
+ attention_mask=None, **kwargs):
+ if self.full_attention_share:
+ _b, n, _d = hidden_states.shape
+ hidden_states = einops.rearrange(hidden_states, '(k b) n d -> k (b n) d', k=2)
+ hidden_states = super().__call__(attn, hidden_states, encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask, **kwargs)
+ hidden_states = einops.rearrange(hidden_states, 'k (b n) d -> (k b) n d', n=n)
+ else:
+ hidden_states = self.shared_call(attn, hidden_states, hidden_states, attention_mask, **kwargs)
+
+ return hidden_states
+
+ def __init__(self, style_aligned_args: StyleAlignedArgs):
+ super().__init__()
+ self.share_attention = style_aligned_args.share_attention
+ self.adain_queries = style_aligned_args.adain_queries
+ self.adain_keys = style_aligned_args.adain_keys
+ self.adain_values = style_aligned_args.adain_values
+ self.full_attention_share = style_aligned_args.full_attention_share
+ self.shared_score_scale = style_aligned_args.shared_score_scale
+ self.shared_score_shift = style_aligned_args.shared_score_shift
+
+
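+# Build a boolean mask over the self-attention layers selecting roughly a `level`
+# fraction of evenly spaced layers (0 selects none, 1 selects all).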
+def _get_switch_vec(total_num_layers, level):
+ if level <= 0:
+ return torch.zeros(total_num_layers, dtype=torch.bool)
+ if level >= 1:
+ return torch.ones(total_num_layers, dtype=torch.bool)
+ to_flip = level > .5
+ if to_flip:
+ level = 1 - level
+ num_switch = int(level * total_num_layers)
+ vec = torch.arange(total_num_layers)
+ vec = vec % (total_num_layers // num_switch)
+ vec = vec == 0
+ if to_flip:
+ vec = ~vec
+ return vec
+
+
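+# Install attention processors: self-attention ('attn1') layers receive the shared
+# processor unless selected by the only_self_level switch vector; cross-attention
+# layers, and a None args reset, keep the default processor.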
+def init_attention_processors(pipeline: StableDiffusionXLPipeline, style_aligned_args: StyleAlignedArgs | None = None):
+ attn_procs = {}
+ unet = pipeline.unet
+ number_of_self, number_of_cross = 0, 0
+ num_self_layers = len([name for name in unet.attn_processors.keys() if 'attn1' in name])
+ if style_aligned_args is None:
+ only_self_vec = _get_switch_vec(num_self_layers, 1)
+ else:
+ only_self_vec = _get_switch_vec(num_self_layers, style_aligned_args.only_self_level)
+ for i, name in enumerate(unet.attn_processors.keys()):
+ is_self_attention = 'attn1' in name
+ if is_self_attention:
+ number_of_self += 1
+ if style_aligned_args is None or only_self_vec[i // 2]:
+ attn_procs[name] = DefaultAttentionProcessor()
+ else:
+ attn_procs[name] = SharedAttentionProcessor(style_aligned_args)
+ else:
+ number_of_cross += 1
+ attn_procs[name] = DefaultAttentionProcessor()
+
+ unet.set_attn_processor(attn_procs)
+
+
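+# Wrap GroupNorm/LayerNorm forward passes so the expanded reference features are
+# concatenated along the token dimension before normalization and sliced off after,
+# sharing normalization statistics with the reference item.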
+def register_shared_norm(pipeline: StableDiffusionXLPipeline,
+ share_group_norm: bool = True,
+ share_layer_norm: bool = True,
+ ):
+ def register_norm_forward(norm_layer: nn.GroupNorm | nn.LayerNorm) -> nn.GroupNorm | nn.LayerNorm:
+ if not hasattr(norm_layer, 'orig_forward'):
+ setattr(norm_layer, 'orig_forward', norm_layer.forward) # noqa
+ orig_forward = norm_layer.orig_forward
+
+ def forward_(hidden_states: T) -> T:
+ n = hidden_states.shape[-2]
+ hidden_states = concat_first(hidden_states, dim=-2)
+ hidden_states = orig_forward(hidden_states)
+ return hidden_states[..., :n, :]
+
+ norm_layer.forward = forward_
+ return norm_layer
+
+ def get_norm_layers(pipeline_, norm_layers_: dict[str, list[nn.GroupNorm | nn.LayerNorm]]):
+ if isinstance(pipeline_, nn.LayerNorm) and share_layer_norm:
+ norm_layers_['layer'].append(pipeline_)
+ if isinstance(pipeline_, nn.GroupNorm) and share_group_norm:
+ norm_layers_['group'].append(pipeline_)
+ else:
+ for layer in pipeline_.children():
+ get_norm_layers(layer, norm_layers_)
+
+ norm_layers = {'group': [], 'layer': []}
+ get_norm_layers(pipeline.unet, norm_layers)
+ return [register_norm_forward(layer) for layer in norm_layers['group']] + [register_norm_forward(layer) for layer in
+ norm_layers['layer']]
+
+
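+# Attaches or removes style-aligned sharing (shared norms plus attention processors)
+# on an SDXL pipeline.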
+class Handler:
+
+ def register(self, style_aligned_args: StyleAlignedArgs):
+ self.norm_layers = register_shared_norm(self.pipeline, style_aligned_args.share_group_norm,
+ style_aligned_args.share_layer_norm)
+ init_attention_processors(self.pipeline, style_aligned_args)
+
+ def remove(self):
+ for layer in self.norm_layers:
+ layer.forward = layer.orig_forward
+ self.norm_layers = []
+ init_attention_processors(self.pipeline, None)
+
+ def __init__(self, pipeline: StableDiffusionXLPipeline):
+ self.pipeline = pipeline
+ self.norm_layers = []
diff --git a/modules/styles.py b/modules/styles.py
index 0dd48eb7f..d0228d33a 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -52,7 +52,7 @@ def check_files(prompt, wildcard, files):
choice = random.choice(lines).strip(' \n')
if '|' in choice:
choice = random.choice(choice.split('|')).strip(' []{}\n')
- prompt = prompt.replace(f"__{wildcard}__", choice)
+ prompt = prompt.replace(f"__{wildcard}__", choice, 1)
shared.log.debug(f'Wildcards apply: wildcard="{wildcard}" choice="{choice}" file="{file}" choices={len(lines)}')
replaced.append(wildcard)
return prompt, True
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index dd4203a4f..de12021a5 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -4,13 +4,14 @@
import torch
import safetensors.torch
from PIL import Image
-from modules import shared, devices, sd_models, errors
+from modules import shared, devices, errors
from modules.textual_inversion.image_embedding import embedding_from_b64, extract_image_data_embed
from modules.files_cache import directory_files, directory_mtime, extension_filter
debug = shared.log.trace if os.environ.get('SD_TI_DEBUG', None) is not None else lambda *args, **kwargs: None
debug('Trace: TEXTUAL INVERSION')
+supported_models = ['ldm', 'sd', 'sdxl']
def list_embeddings(*dirs):
@@ -274,6 +275,8 @@ def load_diffusers_embedding(self, filename: Union[str, List[str]] = None, data:
overwrite = bool(data)
if not shared.sd_loaded:
return
+ if not shared.opts.diffusers_enable_embed:
+ return
embeddings, skipped = open_embeddings(filename) or convert_bundled(data)
for skip in skipped:
self.skipped_embeddings[skip.name] = skipped
@@ -368,7 +371,7 @@ def load_from_file(self, path, filename):
self.skipped_embeddings[name] = embedding
def load_from_dir(self, embdir):
- if sd_models.model_data.sd_model is None:
+ if not shared.sd_loaded:
shared.log.info('Skipping embeddings load: model not loaded')
return
if not os.path.isdir(embdir.path):
@@ -388,6 +391,8 @@ def load_from_dir(self, embdir):
def load_textual_inversion_embeddings(self, force_reload=False):
if not shared.sd_loaded:
return
+ if shared.sd_model_type not in supported_models:
+ return
t0 = time.time()
if not force_reload:
need_reload = False
diff --git a/modules/timer.py b/modules/timer.py
index 8a5db726d..43e859140 100644
--- a/modules/timer.py
+++ b/modules/timer.py
@@ -7,6 +7,7 @@ def __init__(self):
self.start = time.time()
self.records = {}
self.total = 0
+ self.profile = False
def elapsed(self, reset=True):
end = time.time()
@@ -15,17 +16,24 @@ def elapsed(self, reset=True):
self.start = end
return res
+ def add(self, name, t):
+ if name not in self.records:
+ self.records[name] = t
+ else:
+ self.records[name] += t
+
def record(self, category=None, extra_time=0, reset=True):
e = self.elapsed(reset)
if category is None:
category = sys._getframe(1).f_code.co_name # pylint: disable=protected-access
if category not in self.records:
self.records[category] = 0
-
self.records[category] += e + extra_time
self.total += e + extra_time
def summary(self, min_time=0.05, total=True):
+ if self.profile:
+ min_time = -1
res = f"{self.total:.2f} " if total else ''
additions = [x for x in self.records.items() if x[1] >= min_time]
if not additions:
@@ -34,6 +42,8 @@ def summary(self, min_time=0.05, total=True):
return res
def dct(self, min_time=0.05):
+ if self.profile:
+ return {k: round(v, 4) for k, v in self.records.items()}
return {k: round(v, 2) for k, v in self.records.items() if v >= min_time}
def reset(self):
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 2f0e2f4b3..e82c744a2 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -88,7 +88,7 @@ def txt2img(id_task, state,
p.scripts = scripts.scripts_txt2img
p.script_args = args
p.state = state
- processed = scripts.scripts_txt2img.run(p, *args)
+ processed: processing.Processed = scripts.scripts_txt2img.run(p, *args)
if processed is None:
processed = processing.process_images(p)
processed = scripts.scripts_txt2img.after(p, processed, *args)
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 9c4bb5cdc..3e7c68bec 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -245,10 +245,18 @@ def create_output_panel(tabname, preview=True, prompt=None, height=None):
gr.HTML(value="", elem_id="main_info", visible=False, elem_classes=["main-info"])
# columns are for <576px, <768px, <992px, <1200px, <1400px, >1400px
result_gallery = gr.Gallery(value=[],
- label='Output', show_label=False, show_download_button=True, allow_preview=True, container=False, preview=preview,
- columns=4, object_fit='scale-down', height=height,
+ label='Output',
+ show_label=False,
+ show_download_button=True,
+ allow_preview=True,
+ container=False,
+ preview=preview,
+ columns=4,
+ object_fit='scale-down',
+ height=height,
elem_id=f"{tabname}_gallery",
- )
+ elem_classes=["gallery_main"],
+ )
if prompt is not None:
interrogate_clip_btn, interrogate_booru_btn = ui_sections.create_interrogate_buttons('control')
interrogate_clip_btn.click(fn=interrogate_clip, inputs=[result_gallery], outputs=[prompt])
diff --git a/modules/ui_control.py b/modules/ui_control.py
index 0bf070036..cd78313a7 100644
--- a/modules/ui_control.py
+++ b/modules/ui_control.py
@@ -9,11 +9,11 @@
from modules.control.units import lite # vislearn ControlNet-XS
from modules.control.units import t2iadapter # TencentARC T2I-Adapter
from modules.control.units import reference # reference pipeline
-from modules import errors, shared, progress, ui_components, ui_symbols, ui_common, ui_sections, generation_parameters_copypaste, call_queue, scripts, masking, images, processing_vae # pylint: disable=ungrouped-imports
+from modules import errors, shared, progress, ui_components, ui_symbols, ui_common, ui_sections, generation_parameters_copypaste, call_queue, scripts, masking, images, processing_vae, timer # pylint: disable=ungrouped-imports
from modules import ui_control_helpers as helpers
-gr_height = None
+gr_height = 512
max_units = shared.opts.control_max_units
units: list[unit.Unit] = [] # main state variable
controls: list[gr.component] = [] # list of gr controls
@@ -21,13 +21,41 @@
debug('Trace: CONTROL')
-def return_controls(res):
+def return_stats(t: float = None):
+ if t is None:
+ elapsed_text = ''
+ else:
+ elapsed = time.perf_counter() - t
+ elapsed_m = int(elapsed // 60)
+ elapsed_s = elapsed % 60
+ elapsed_text = f"Time: {elapsed_m}m {elapsed_s:.2f}s |" if elapsed_m > 0 else f"Time: {elapsed_s:.2f}s |"
+ summary = timer.process.summary(min_time=0.25, total=False).replace('=', ' ')
+ gpu = ''
+ cpu = ''
+ if not shared.mem_mon.disabled:
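+ # negated floor division rounds up: converts bytes to whole MB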
+ vram = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.read().items()}
+ peak = max(vram['active_peak'], vram['reserved_peak'], vram['used'])
+ used = round(100.0 * peak / vram['total']) if vram['total'] > 0 else 0
+ if used > 0:
+ gpu += f"| GPU {peak} MB {used}%"
+ gpu += f" | retries {vram['retries']} oom {vram['oom']}" if vram.get('retries', 0) > 0 or vram.get('oom', 0) > 0 else ''
+ ram = shared.ram_stats()
+ if ram['used'] > 0:
+ cpu += f"| RAM {ram['used']} GB {round(100.0 * ram['used'] / ram['total'])}%"
+ return f""
+
+
+def return_controls(res, t: float = None):
# return preview, image, video, gallery, text
debug(f'Control received: type={type(res)} {res}')
+ if t is None:
+ perf = ''
+ else:
+ perf = return_stats(t)
if res is None: # no response
- return [None, None, None, None, '']
+ return [None, None, None, None, '', perf]
elif isinstance(res, str): # error response
- return [None, None, None, None, res]
+ return [None, None, None, None, res, perf]
elif isinstance(res, tuple): # standard response received as tuple via control_run->yield(output_images, process_image, result_txt)
preview_image = res[1] # may be None
output_image = res[0][0] if isinstance(res[0], list) else res[0] # may be image or list of images
@@ -37,9 +65,9 @@ def return_controls(res):
output_gallery = [res[0]] if res[0] is not None else [] # must return list, but can receive single image
result_txt = res[2] if len(res) > 2 else '' # do we have a message
output_video = res[3] if len(res) > 3 else None # do we have a video filename
- return [preview_image, output_image, output_video, output_gallery, result_txt]
+ return [preview_image, output_image, output_video, output_gallery, result_txt, perf]
else: # unexpected
- return [None, None, None, None, f'Control: Unexpected response: {type(res)}']
+ return [None, None, None, None, f'Control: Unexpected response: {type(res)}', perf]
def get_units(*values):
@@ -67,17 +95,18 @@ def generate_click(job_id: str, state: str, active_tab: str, *args):
shared.state.begin('Generate')
progress.add_task_to_queue(job_id)
with call_queue.queue_lock:
- yield [None, None, None, None, 'Control: starting']
+ yield [None, None, None, None, 'Control: starting', '']
shared.mem_mon.reset()
progress.start_task(job_id)
try:
+ t = time.perf_counter()
for results in control_run(state, units, helpers.input_source, helpers.input_init, helpers.input_mask, active_tab, True, *args):
progress.record_results(job_id, results)
- yield return_controls(results)
+ yield return_controls(results, t)
except Exception as e:
shared.log.error(f"Control exception: {e}")
errors.display(e, 'Control')
- yield [None, None, None, None, f'Control: Exception: {e}']
+ yield [None, None, None, None, f'Control: Exception: {e}', '']
progress.finish_task(job_id)
shared.state.end()
@@ -106,11 +135,12 @@ def create_ui(_blocks: gr.Blocks=None):
with gr.Accordion(open=False, label="Input", elem_id="control_input", elem_classes=["small-accordion"]):
with gr.Row():
- show_preview = gr.Checkbox(label="Show preview", value=True, elem_id="control_show_preview")
+ show_input = gr.Checkbox(label="Show input", value=True, elem_id="control_show_input")
+ show_preview = gr.Checkbox(label="Show preview", value=False, elem_id="control_show_preview")
with gr.Row():
- input_type = gr.Radio(label="Input type", choices=['Control only', 'Init image same as control', 'Separate init image'], value='Control only', type='index', elem_id='control_input_type')
+ input_type = gr.Radio(label="Control input type", choices=['Control only', 'Init image same as control', 'Separate init image'], value='Control only', type='index', elem_id='control_input_type')
with gr.Row():
- denoising_strength = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label='Denoising strength', value=0.50, elem_id="control_input_denoising_strength")
+ denoising_strength = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label='Denoising strength', value=0.30, elem_id="control_input_denoising_strength")
with gr.Accordion(open=False, label="Size", elem_id="control_size", elem_classes=["small-accordion"]):
with gr.Tabs():
@@ -153,13 +183,13 @@ def create_ui(_blocks: gr.Blocks=None):
override_settings = ui_common.create_override_inputs('control')
with gr.Row(variant='compact', elem_id="control_extra_networks", elem_classes=["extra_networks_root"], visible=False) as extra_networks_ui:
- from modules import timer, ui_extra_networks
+ from modules import ui_extra_networks
extra_networks_ui = ui_extra_networks.create_ui(extra_networks_ui, btn_extra, 'control', skip_indexing=shared.opts.extra_network_skip_indexing)
timer.startup.record('ui-networks')
with gr.Row(elem_id='control-inputs'):
- with gr.Column(scale=9, elem_id='control-input-column', visible=True) as _column_input:
-            gr.HTML('Control input')
+ with gr.Column(scale=9, elem_id='control-input-column', visible=True) as column_input:
+ gr.HTML('Input')
with gr.Tabs(elem_classes=['control-tabs'], elem_id='control-tab-input'):
with gr.Tab('Image', id='in-image') as tab_image:
input_mode = gr.Label(value='select', visible=False)
@@ -190,12 +220,12 @@ def create_ui(_blocks: gr.Blocks=None):
gr.HTML('Output')
with gr.Tabs(elem_classes=['control-tabs'], elem_id='control-tab-output') as output_tabs:
with gr.Tab('Gallery', id='out-gallery'):
- output_gallery, _output_gen_info, _output_html_info, _output_html_info_formatted, _output_html_log = ui_common.create_output_panel("control", preview=True, prompt=prompt, height=gr_height)
+ output_gallery, _output_gen_info, _output_html_info, _output_html_info_formatted, output_html_log = ui_common.create_output_panel("control", preview=True, prompt=prompt, height=gr_height)
with gr.Tab('Image', id='out-image'):
output_image = gr.Image(label="Output", show_label=False, type="pil", interactive=False, tool="editor", height=gr_height, elem_id='control_output_image', elem_classes=['control-image'])
with gr.Tab('Video', id='out-video'):
output_video = gr.Video(label="Output", show_label=False, height=gr_height, elem_id='control_output_video', elem_classes=['control-image'])
- with gr.Column(scale=9, elem_id='control-preview-column', visible=True) as column_preview:
+ with gr.Column(scale=9, elem_id='control-preview-column', visible=False) as column_preview:
gr.HTML('Preview')
with gr.Tabs(elem_classes=['control-tabs'], elem_id='control-tab-preview'):
with gr.Tab('Preview', id='preview-image') as _tab_preview:
@@ -220,10 +250,11 @@ def create_ui(_blocks: gr.Blocks=None):
process_id = gr.Dropdown(label="Processor", choices=processors.list_models(), value='None', elem_id=f'control_unit-{i}-process_name')
model_id = gr.Dropdown(label="ControlNet", choices=controlnet.list_models(), value='None', elem_id=f'control_unit-{i}-model_name')
ui_common.create_refresh_button(model_id, controlnet.list_models, lambda: {"choices": controlnet.list_models(refresh=True)}, f'refresh_controlnet_models_{i}')
+ control_mode = gr.Dropdown(label="CN Mode", choices=['default'], value='default', visible=False, elem_id=f'control_unit-{i}-mode')
model_strength = gr.Slider(label="CN Strength", minimum=0.01, maximum=2.0, step=0.01, value=1.0, elem_id=f'control_unit-{i}-strength')
- control_start = gr.Slider(label="Start", minimum=0.0, maximum=1.0, step=0.05, value=0, elem_id=f'control_unit-{i}-start')
- control_end = gr.Slider(label="End", minimum=0.0, maximum=1.0, step=0.05, value=1.0, elem_id=f'control_unit-{i}-end')
- control_mode = gr.Dropdown(label="CN Mode", choices=['', 'Canny', 'Tile', 'Depth', 'Blur', 'Pose', 'Gray', 'LQ'], value=0, type='index', visible=False, elem_id=f'control_unit-{i}-mode')
+ control_start = gr.Slider(label="CN Start", minimum=0.0, maximum=1.0, step=0.05, value=0, elem_id=f'control_unit-{i}-start')
+ control_end = gr.Slider(label="CN End", minimum=0.0, maximum=1.0, step=0.05, value=1.0, elem_id=f'control_unit-{i}-end')
+ control_tile = gr.Dropdown(label="CN Tiles", choices=[x.strip() for x in shared.opts.control_tiles.split(',') if 'x' in x], value='1x1', visible=False, elem_id=f'control_unit-{i}-tile')
reset_btn = ui_components.ToolButton(value=ui_symbols.reset)
image_upload = gr.UploadButton(label=ui_symbols.upload, file_types=['image'], elem_classes=['form', 'gradio-button', 'tool'])
image_reuse= ui_components.ToolButton(value=ui_symbols.reuse)
@@ -248,6 +279,7 @@ def create_ui(_blocks: gr.Blocks=None):
control_start = control_start,
control_end = control_end,
control_mode = control_mode,
+ control_tile = control_tile,
extra_controls = extra_controls,
)
)
@@ -498,6 +530,7 @@ def create_ui(_blocks: gr.Blocks=None):
btn_update = gr.Button('Update', interactive=True, visible=False, elem_id='control_update')
btn_update.click(fn=get_units, inputs=controls, outputs=[], show_progress=True, queue=False)
+ show_input.change(fn=lambda x: gr.update(visible=x), inputs=[show_input], outputs=[column_input])
show_preview.change(fn=lambda x: gr.update(visible=x), inputs=[show_preview], outputs=[column_preview])
input_type.change(fn=lambda x: gr.update(visible=x == 2), inputs=[input_type], outputs=[column_init])
btn_prompt_counter.click(fn=call_queue.wrap_queued_call(ui_common.update_token_counter), inputs=[prompt, steps], outputs=[prompt_counter])
@@ -550,6 +583,7 @@ def create_ui(_blocks: gr.Blocks=None):
output_video,
output_gallery,
result_txt,
+ output_html_log,
]
control_dict = dict(
fn=generate_click,
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index f6e6cee97..94664c5cb 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -16,7 +16,7 @@
import gradio as gr
from PIL import Image
from starlette.responses import FileResponse, JSONResponse
-from modules import paths, shared, scripts, files_cache, errors, infotext
+from modules import paths, shared, files_cache, errors, infotext
from modules.ui_components import ToolButton
import modules.ui_symbols as symbols
@@ -135,6 +135,7 @@ def patch(self, text: str, tabname: str):
return text.replace('~tabname', tabname)
def create_xyz_grid(self):
+ """
xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module
def add_prompt(p, opt, x):
@@ -150,6 +151,7 @@ def add_prompt(p, opt, x):
opt = xyz_grid.AxisOption(f"[Network] {self.title}", str, add_prompt, choices=lambda: [x["name"] for x in self.items])
if opt not in xyz_grid.axis_options:
xyz_grid.axis_options.append(opt)
+ """
def link_preview(self, filename):
quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
@@ -458,17 +460,20 @@ def register_page(page: ExtraNetworksPage):
def register_pages():
- from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
+ debug('EN register-pages')
from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
- from modules.ui_extra_networks_styles import ExtraNetworksPageStyles
from modules.ui_extra_networks_vae import ExtraNetworksPageVAEs
+ from modules.ui_extra_networks_styles import ExtraNetworksPageStyles
from modules.ui_extra_networks_history import ExtraNetworksPageHistory
- debug('EN register-pages')
+ from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
register_page(ExtraNetworksPageCheckpoints())
- register_page(ExtraNetworksPageStyles())
- register_page(ExtraNetworksPageTextualInversion())
register_page(ExtraNetworksPageVAEs())
+ register_page(ExtraNetworksPageStyles())
register_page(ExtraNetworksPageHistory())
+ register_page(ExtraNetworksPageTextualInversion())
+ if shared.native:
+ from modules.ui_extra_networks_lora import ExtraNetworksPageLora
+ register_page(ExtraNetworksPageLora())
if shared.opts.hypernetwork_enabled:
from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
register_page(ExtraNetworksPageHypernetworks())
diff --git a/modules/ui_extra_networks_lora.py b/modules/ui_extra_networks_lora.py
new file mode 100644
index 000000000..9dd1b3573
--- /dev/null
+++ b/modules/ui_extra_networks_lora.py
@@ -0,0 +1,123 @@
+import os
+import json
+import concurrent
+import modules.lora.networks as networks
+from modules import shared, ui_extra_networks
+
+
+debug = os.environ.get('SD_LORA_DEBUG', None) is not None
+
+
+class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Lora')
+ self.list_time = 0
+
+ def refresh(self):
+ networks.list_available_networks()
+
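+    # collect tags for a LoRA card: modelspec.tags and ss_tag_frequency from model metadata plus trigger words and tags from the civitai info json, with unsafe characters stripped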
+ @staticmethod
+ def get_tags(l, info):
+ tags = {}
+ try:
+ if l.metadata is not None:
+ modelspec_tags = l.metadata.get('modelspec.tags', {})
+                possible_tags = l.metadata.get('ss_tag_frequency', {}) # tags from model metadata
+ if isinstance(possible_tags, str):
+ possible_tags = {}
+ if isinstance(modelspec_tags, str):
+ modelspec_tags = {}
+ if len(list(modelspec_tags)) > 0:
+ possible_tags.update(modelspec_tags)
+ for k, v in possible_tags.items():
+ words = k.split('_', 1) if '_' in k else [v, k]
+ words = [str(w).replace('.json', '') for w in words]
+ if words[0] == '{}':
+ words[0] = 0
+ tag = ' '.join(words[1:]).lower()
+ tags[tag] = words[0]
+
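+            # match this LoRA's short hash against the file hashes listed in the info json model versions; fall back to all versions when nothing matches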
+ def find_version():
+ found_versions = []
+ current_hash = l.hash[:8].upper()
+ all_versions = info.get('modelVersions', [])
+ for v in info.get('modelVersions', []):
+ for f in v.get('files', []):
+ if any(h.startswith(current_hash) for h in f.get('hashes', {}).values()):
+ found_versions.append(v)
+ if len(found_versions) == 0:
+ found_versions = all_versions
+ return found_versions
+
+ for v in find_version(): # trigger words from info json
+ possible_tags = v.get('trainedWords', [])
+ if isinstance(possible_tags, list):
+ for tag_str in possible_tags:
+ for tag in tag_str.split(','):
+ tag = tag.strip().lower()
+ if tag not in tags:
+ tags[tag] = 0
+
+ possible_tags = info.get('tags', []) # tags from info json
+ if not isinstance(possible_tags, list):
+ possible_tags = list(possible_tags.values())
+ for tag in possible_tags:
+ tag = tag.strip().lower()
+ if tag not in tags:
+ tags[tag] = 0
+ except Exception:
+ pass
+ bad_chars = [';', ':', '<', ">", "*", '?', '\'', '\"', '(', ')', '[', ']', '{', '}', '\\', '/']
+ clean_tags = {}
+ for k, v in tags.items():
+ tag = ''.join(i for i in k if i not in bad_chars).strip()
+ clean_tags[tag] = v
+
+ clean_tags.pop('img', None)
+ clean_tags.pop('dataset', None)
+ return clean_tags
+
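+    # build a single gallery item for a registered network: relative name, hash, metadata, size/mtime plus info, description and tags read from disk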
+ def create_item(self, name):
+ l = networks.available_networks.get(name)
+ if l is None:
+ shared.log.warning(f'Networks: type=lora registered={len(list(networks.available_networks))} file="{name}" not registered')
+ return None
+ try:
+ # path, _ext = os.path.splitext(l.filename)
+ name = os.path.splitext(os.path.relpath(l.filename, shared.cmd_opts.lora_dir))[0]
+ item = {
+ "type": 'Lora',
+ "name": name,
+ "filename": l.filename,
+ "hash": l.shorthash,
+                "prompt": json.dumps(f" <lora:{name}:{shared.opts.extra_networks_default_multiplier}>"),
+ "metadata": json.dumps(l.metadata, indent=4) if l.metadata else None,
+ "mtime": os.path.getmtime(l.filename),
+ "size": os.path.getsize(l.filename),
+ "version": l.sd_version,
+ }
+ info = self.find_info(l.filename)
+ item["info"] = info
+ item["description"] = self.find_description(l.filename, info) # use existing info instead of double-read
+ item["tags"] = self.get_tags(l, info)
+ return item
+ except Exception as e:
+ shared.log.error(f'Networks: type=lora file="{name}" {e}')
+ if debug:
+ from modules import errors
+ errors.display(e, 'Lora')
+ return None
+
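+    # create items in parallel with a thread pool since each item may read metadata, info and description files from disk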
+ def list_items(self):
+ items = []
+ with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
+ future_items = {executor.submit(self.create_item, net): net for net in networks.available_networks}
+ for future in concurrent.futures.as_completed(future_items):
+ item = future.result()
+ if item is not None:
+ items.append(item)
+ self.update_all_previews(items)
+ return items
+
+ def allowed_directories_for_previews(self):
+ return [shared.cmd_opts.lora_dir]
diff --git a/modules/ui_img2img.py b/modules/ui_img2img.py
index 22c89dac8..45c901c6d 100644
--- a/modules/ui_img2img.py
+++ b/modules/ui_img2img.py
@@ -1,7 +1,6 @@
import os
from PIL import Image
import gradio as gr
-import numpy as np
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call
from modules import timer, shared, ui_common, ui_sections, generation_parameters_copypaste, processing_vae
@@ -56,7 +55,7 @@ def copy_image(img):
def add_copy_image_controls(tab_name, elem):
with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"):
- for title, name in zip(['➠ Image', '➠ Sketch', '➠ Inpaint', '➠ Composite'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']):
+ for title, name in zip(['➠ Image', '➠ Inpaint', '➠ Sketch', '➠ Composite'], ['img2img', 'sketch', 'inpaint', 'composite']):
if name == tab_name:
gr.Button(title, elem_id=f'copy_to_{name}', interactive=False)
copy_image_destinations[name] = elem
@@ -67,45 +66,47 @@ def add_copy_image_controls(tab_name, elem):
with gr.Tabs(elem_id="mode_img2img"):
img2img_selected_tab = gr.State(0) # pylint: disable=abstract-class-instantiated
state = gr.Textbox(value='', visible=False)
- with gr.TabItem('Image', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=512)
+ with gr.TabItem('Image', id='img2img_image', elem_id="img2img_image_tab") as tab_img2img:
+ img_init = gr.Image(label="", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=512)
interrogate_clip, interrogate_booru = ui_sections.create_interrogate_buttons('img2img')
- add_copy_image_controls('img2img', init_img)
-
- with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
- sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=512)
- add_copy_image_controls('sketch', sketch)
-
- with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=512)
- add_copy_image_controls('inpaint', init_img_with_mask)
-
- with gr.TabItem('Composite', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
- inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=512)
- inpaint_color_sketch_orig = gr.State(None) # pylint: disable=abstract-class-instantiated
- add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
-
- def update_orig(image, state):
- if image is not None:
- same_size = state is not None and state.size == image.size
- has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
- edited = same_size and has_exact_match
- return image if not edited or state is None else state
- return state
-
- inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
+ add_copy_image_controls('img2img', img_init)
+
+ with gr.TabItem('Inpaint', id='img2img_inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
+ img_inpaint = gr.Image(label="", elem_id="img2img_inpaint", show_label=False, source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=512)
+ add_copy_image_controls('inpaint', img_inpaint)
+
+ with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_sketch_tab") as tab_sketch:
+ img_sketch = gr.Image(label="", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=512)
+ add_copy_image_controls('sketch', img_sketch)
+
+ with gr.TabItem('Composite', id='img2img_composite', elem_id="img2img_composite_tab") as tab_inpaint_color:
+ img_composite = gr.Image(label="", show_label=False, elem_id="img2img_composite", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=512)
+ img_composite_orig = gr.State(None) # pylint: disable=abstract-class-instantiated
+ img_composite_orig_update = False
+
+ def fn_img_composite_upload():
+ nonlocal img_composite_orig_update
+ img_composite_orig_update = True
+ def fn_img_composite_change(img, img_composite):
+ nonlocal img_composite_orig_update
+ res = img if img_composite_orig_update else img_composite
+ img_composite_orig_update = False
+ return res
+
+ img_composite.upload(fn=fn_img_composite_upload, inputs=[], outputs=[])
+ img_composite.change(fn=fn_img_composite_change, inputs=[img_composite, img_composite_orig], outputs=[img_composite_orig])
+ add_copy_image_controls('composite', img_composite)
with gr.TabItem('Upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask")
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
-                        hidden = ' Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
-                        gr.HTML(f"Upload images or process images in a directory Add inpaint batch mask directory to enable inpaint batch processing {hidden}")
+                        gr.HTML("Run image processing on upload images or files in a folder If masks are provided will run inpaint")
img2img_batch_files = gr.Files(label="Batch Process", interactive=True, elem_id="img2img_image_batch")
- img2img_batch_input_dir = gr.Textbox(label="Inpaint batch input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
- img2img_batch_output_dir = gr.Textbox(label="Inpaint batch output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
- img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
+ img2img_batch_input_dir = gr.Textbox(label="Batch input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
+ img2img_batch_output_dir = gr.Textbox(label="Batch output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
+ img2img_batch_inpaint_mask_dir = gr.Textbox(label="Batch mask directory", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
for i, tab in enumerate(img2img_tabs):
@@ -120,13 +121,13 @@ def update_orig(image, state):
with gr.Accordion(open=False, label="Sampler", elem_classes=["small-accordion"], elem_id="img2img_sampler_group"):
steps, sampler_index = ui_sections.create_sampler_and_steps_selection(None, "img2img")
ui_sections.create_sampler_options('img2img')
- resize_mode, resize_name, resize_context, width, height, scale_by, selected_scale_tab = ui_sections.create_resize_inputs('img2img', [init_img, sketch], latent=True, non_zero=False)
+ resize_mode, resize_name, resize_context, width, height, scale_by, selected_scale_tab = ui_sections.create_resize_inputs('img2img', [img_init, img_sketch], latent=True, non_zero=False)
batch_count, batch_size = ui_sections.create_batch_inputs('img2img', accordion=True)
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w = ui_sections.create_seed_inputs('img2img')
with gr.Accordion(open=False, label="Denoise", elem_classes=["small-accordion"], elem_id="img2img_denoise_group"):
with gr.Row():
- denoising_strength = gr.Slider(minimum=0.0, maximum=0.99, step=0.01, label='Denoising strength', value=0.50, elem_id="img2img_denoising_strength")
+ denoising_strength = gr.Slider(minimum=0.0, maximum=0.99, step=0.01, label='Denoising strength', value=0.30, elem_id="img2img_denoising_strength")
refiner_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Denoise start', value=0.0, elem_id="img2img_refiner_start")
full_quality, tiling, hidiffusion, cfg_scale, clip_skip, image_cfg_scale, diffusers_guidance_rescale, pag_scale, pag_adaptive, cfg_end = ui_sections.create_advanced_inputs('img2img')
@@ -167,13 +168,8 @@ def select_img2img_tab(tab):
img2img_args = [
dummy_component1, state, dummy_component2,
img2img_prompt, img2img_negative_prompt, img2img_prompt_styles,
- init_img,
- sketch,
- init_img_with_mask,
- inpaint_color_sketch,
- inpaint_color_sketch_orig,
- init_img_inpaint,
- init_mask_inpaint,
+ img_init, img_sketch, img_inpaint, img_composite, img_composite_orig,
+ init_img_inpaint, init_mask_inpaint,
steps,
sampler_index,
mask_blur, mask_alpha,
@@ -225,10 +221,7 @@ def select_img2img_tab(tab):
img2img_batch_files,
img2img_batch_input_dir,
img2img_batch_output_dir,
- init_img,
- sketch,
- init_img_with_mask,
- inpaint_color_sketch,
+ img_init, img_sketch, img_inpaint, img_composite,
init_img_inpaint,
],
outputs=[img2img_prompt, dummy_component],
@@ -285,7 +278,8 @@ def select_img2img_tab(tab):
(seed_resize_from_h, "Seed resize from-2"),
*modules.scripts.scripts_img2img.infotext_fields
]
- generation_parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
- generation_parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
+ generation_parameters_copypaste.add_paste_fields("img2img", img_init, img2img_paste_fields, override_settings)
+ generation_parameters_copypaste.add_paste_fields("sketch", img_sketch, img2img_paste_fields, override_settings)
+ generation_parameters_copypaste.add_paste_fields("inpaint", img_inpaint, img2img_paste_fields, override_settings)
img2img_bindings = generation_parameters_copypaste.ParamBinding(paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None)
generation_parameters_copypaste.register_paste_params_button(img2img_bindings)
diff --git a/modules/ui_models.py b/modules/ui_models.py
index 624c3849d..7ab8b0d07 100644
--- a/modules/ui_models.py
+++ b/modules/ui_models.py
@@ -8,7 +8,7 @@
from modules.ui_components import ToolButton
from modules.ui_common import create_refresh_button
from modules.call_queue import wrap_gradio_gpu_call
-from modules.shared import opts, log, req, readfile, max_workers
+from modules.shared import opts, log, req, readfile, max_workers, native
import modules.ui_symbols
import modules.errors
import modules.hashes
@@ -794,6 +794,10 @@ def civit_update_download():
civit_results4.select(fn=civit_update_select, inputs=[civit_results4], outputs=[models_outcome, civit_update_download_btn])
civit_update_download_btn.click(fn=civit_update_download, inputs=[], outputs=[models_outcome])
+ if native:
+ from modules.lora.lora_extract import create_ui as lora_extract_ui
+ lora_extract_ui()
+
for ui in extra_ui:
if callable(ui):
ui()
diff --git a/modules/ui_sections.py b/modules/ui_sections.py
index f15edb4bd..fcf53cf70 100644
--- a/modules/ui_sections.py
+++ b/modules/ui_sections.py
@@ -276,11 +276,11 @@ def set_sampler_preset(preset):
else: # shared.native
with gr.Row(elem_classes=['flex-break']):
- sampler_sigma = gr.Dropdown(label='Sigma method', elem_id=f"{tabname}_sampler_sigma", choices=['default', 'karras', 'betas', 'exponential', 'lambdas'], value=shared.opts.schedulers_sigma, type='value')
+ sampler_sigma = gr.Dropdown(label='Sigma method', elem_id=f"{tabname}_sampler_sigma", choices=['default', 'karras', 'betas', 'exponential', 'lambdas', 'flowmatch'], value=shared.opts.schedulers_sigma, type='value')
sampler_spacing = gr.Dropdown(label='Timestep spacing', elem_id=f"{tabname}_sampler_spacing", choices=['default', 'linspace', 'leading', 'trailing'], value=shared.opts.schedulers_timestep_spacing, type='value')
with gr.Row(elem_classes=['flex-break']):
sampler_beta = gr.Dropdown(label='Beta schedule', elem_id=f"{tabname}_sampler_beta", choices=['default', 'linear', 'scaled', 'cosine'], value=shared.opts.schedulers_beta_schedule, type='value')
- sampler_prediction = gr.Dropdown(label='Prediction method', elem_id=f"{tabname}_sampler_prediction", choices=['default', 'epsilon', 'sample', 'v_prediction'], value=shared.opts.schedulers_prediction_type, type='value')
+ sampler_prediction = gr.Dropdown(label='Prediction method', elem_id=f"{tabname}_sampler_prediction", choices=['default', 'epsilon', 'sample', 'v_prediction', 'flow_prediction'], value=shared.opts.schedulers_prediction_type, type='value')
with gr.Row(elem_classes=['flex-break']):
sampler_presets = gr.Dropdown(label='Timesteps presets', elem_id=f"{tabname}_sampler_presets", choices=['None', 'AYS SD15', 'AYS SDXL'], value='None', type='value')
sampler_timesteps = gr.Textbox(label='Timesteps override', elem_id=f"{tabname}_sampler_timesteps", value=shared.opts.schedulers_timesteps)
diff --git a/modules/xadapter/adapter.py b/modules/xadapter/adapter.py
index 69030fae3..4096f71e7 100644
--- a/modules/xadapter/adapter.py
+++ b/modules/xadapter/adapter.py
@@ -266,8 +266,6 @@ def forward(self, x, t=None):
b, c, _, _ = x[-1].shape
if t is not None:
if not torch.is_tensor(t):
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
- # This would be a good case for the `match` statement (Python 3.10+)
is_mps = x[0].device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
diff --git a/modules/xadapter/unet_adapter.py b/modules/xadapter/unet_adapter.py
index 5022f1847..5890c7749 100644
--- a/modules/xadapter/unet_adapter.py
+++ b/modules/xadapter/unet_adapter.py
@@ -807,8 +807,6 @@ def forward(
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
- # This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
@@ -1012,7 +1010,7 @@ def forward(
if is_bridge:
if up_block_additional_residual[0].shape != sample.shape:
- pass # TODO VM patch
+ pass
elif fusion_guidance_scale is not None:
sample = sample + fusion_guidance_scale * (up_block_additional_residual.pop(0) - sample)
else:
@@ -1051,7 +1049,7 @@ def forward(
################# bridge usage #################
if is_bridge and len(up_block_additional_residual) > 0:
if sample.shape != up_block_additional_residual[0].shape:
- pass # TODO VM PATCH
+ pass
elif fusion_guidance_scale is not None:
sample = sample + fusion_guidance_scale * (up_block_additional_residual.pop(0) - sample)
else:
diff --git a/modules/zluda_installer.py b/modules/zluda_installer.py
index 84c130e8d..5e43a6635 100644
--- a/modules/zluda_installer.py
+++ b/modules/zluda_installer.py
@@ -31,7 +31,10 @@ def install(zluda_path: os.PathLike) -> None:
if os.path.exists(zluda_path):
return
- urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/rel.{os.environ.get("ZLUDA_HASH", "c0804ca624963aab420cb418412b1c7fbae3454b")}/ZLUDA-windows-rocm{rocm.version[0]}-amd64.zip', '_zluda')
+ commit = os.environ.get("ZLUDA_HASH", "1b6e012d8f2404840b524e2abae12cb91e1ac01d")
+ if rocm.version == "6.1":
+ commit = "c0804ca624963aab420cb418412b1c7fbae3454b"
+ urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/rel.{commit}/ZLUDA-windows-rocm{rocm.version[0]}-amd64.zip', '_zluda')
with zipfile.ZipFile('_zluda', 'r') as archive:
infos = archive.infolist()
for info in infos:
diff --git a/requirements.txt b/requirements.txt
index 12a9f85cb..a67ad1ba2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -34,25 +34,25 @@ pi-heif
# versioned
safetensors==0.4.5
tensordict==0.1.2
-peft==0.13.1
+peft==0.14.0
httpx==0.24.1
compel==2.0.3
torchsde==0.2.6
antlr4-python3-runtime==4.9.3
requests==2.32.3
tqdm==4.66.5
-accelerate==1.1.1
+accelerate==1.2.1
opencv-contrib-python-headless==4.9.0.80
einops==0.4.1
gradio==3.43.2
-huggingface_hub==0.26.2
+huggingface_hub==0.27.0
numexpr==2.8.8
numpy==1.26.4
numba==0.59.1
protobuf==4.25.3
pytorch_lightning==1.9.4
-tokenizers==0.20.3
-transformers==4.46.2
+tokenizers==0.21.0
+transformers==4.47.1
urllib3==1.26.19
Pillow==10.4.0
timm==0.9.16
diff --git a/scripts/animatediff.py b/scripts/animatediff.py
index 4c50f9cf6..f44c85bb7 100644
--- a/scripts/animatediff.py
+++ b/scripts/animatediff.py
@@ -182,14 +182,14 @@ def set_free_init(method, iters, order, spatial, temporal):
def set_free_noise(frames):
context_length = 16
context_stride = 4
- if frames >= context_length:
+ if frames >= context_length and hasattr(shared.sd_model, 'enable_free_noise'):
shared.log.debug(f'AnimateDiff free noise: frames={frames} context={context_length} stride={context_stride}')
shared.sd_model.enable_free_noise(context_length=context_length, context_stride=context_stride)
class Script(scripts.Script):
def title(self):
- return 'Video AnimateDiff'
+ return 'Video: AnimateDiff'
def show(self, is_img2img):
# return scripts.AlwaysVisible if shared.native else False
@@ -250,7 +250,7 @@ def run(self, p: processing.StableDiffusionProcessing, adapter_index, frames, lo
processing.fix_seed(p)
p.extra_generation_params['AnimateDiff'] = loaded_adapter
p.do_not_save_grid = True
- p.ops.append('animatediff')
+ p.ops.append('video')
p.task_args['generator'] = None
p.task_args['num_frames'] = frames
p.task_args['num_inference_steps'] = p.steps
diff --git a/scripts/cogvideo.py b/scripts/cogvideo.py
index 7f2c7225e..e689a5e3f 100644
--- a/scripts/cogvideo.py
+++ b/scripts/cogvideo.py
@@ -22,7 +22,7 @@
class Script(scripts.Script):
def title(self):
- return 'Video CogVideoX'
+ return 'Video: CogVideoX'
def show(self, is_img2img):
return shared.native
@@ -51,7 +51,7 @@ def video_type_change(video_type):
with gr.Row():
video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None')
duration = gr.Slider(label='Duration', minimum=0.25, maximum=30, step=0.25, value=8, visible=False)
- with gr.Accordion('Optional init video', open=False):
+ with gr.Accordion('Optional init image or video', open=False):
with gr.Row():
image = gr.Image(value=None, label='Image', type='pil', source='upload', width=256, height=256)
video = gr.Video(value=None, label='Video', source='upload', width=256, height=256)
@@ -169,25 +169,18 @@ def generate(self, p: processing.StableDiffusionProcessing, model: str):
callback_on_step_end=diffusers_callback,
callback_on_step_end_tensor_inputs=['latents'],
)
- if getattr(p, 'image', False):
- if 'I2V' not in model:
- shared.log.error(f'CogVideoX: model={model} image input not supported')
- return []
- args['image'] = self.image(p, p.image)
- args['num_frames'] = p.frames # only txt2vid has num_frames
- shared.sd_model = sd_models.switch_pipe(diffusers.CogVideoXImageToVideoPipeline, shared.sd_model)
- elif getattr(p, 'video', False):
- if 'I2V' in model:
- shared.log.error(f'CogVideoX: model={model} image input not supported')
- return []
- args['video'] = self.video(p, p.video)
- shared.sd_model = sd_models.switch_pipe(diffusers.CogVideoXVideoToVideoPipeline, shared.sd_model)
+ if 'I2V' in model:
+ if hasattr(p, 'video') and p.video is not None:
+ args['video'] = self.video(p, p.video)
+ shared.sd_model = sd_models.switch_pipe(diffusers.CogVideoXVideoToVideoPipeline, shared.sd_model)
+ elif (hasattr(p, 'image') and p.image is not None) or (hasattr(p, 'init_images') and len(p.init_images) > 0):
+ p.init_images = [p.image] if hasattr(p, 'image') and p.image is not None else p.init_images
+ args['image'] = self.image(p, p.init_images[0])
+ shared.sd_model = sd_models.switch_pipe(diffusers.CogVideoXImageToVideoPipeline, shared.sd_model)
else:
- if 'I2V' in model:
- shared.log.error(f'CogVideoX: model={model} image input not supported')
- return []
- args['num_frames'] = p.frames # only txt2vid has num_frames
shared.sd_model = sd_models.switch_pipe(diffusers.CogVideoXPipeline, shared.sd_model)
+ args['num_frames'] = p.frames # only txt2vid has num_frames
+ shared.log.info(f"CogVideoX: class={shared.sd_model.__class__.__name__} frames={p.frames} input={args.get('video', None) or args.get('image', None)}")
if debug:
shared.log.debug(f'CogVideoX args: {args}')
frames = shared.sd_model(**args).frames[0]
@@ -199,7 +192,7 @@ def generate(self, p: processing.StableDiffusionProcessing, model: str):
errors.display(e, 'CogVideoX')
t1 = time.time()
its = (len(frames) * p.steps) / (t1 - t0)
- shared.log.info(f'CogVideoX: frames={len(frames)} its={its:.2f} time={t1 - t0:.2f}')
+ shared.log.info(f'CogVideoX: frame={frames[0] if len(frames) > 0 else None} frames={len(frames)} its={its:.2f} time={t1 - t0:.2f}')
return frames
# auto-executed by the script-callback
@@ -209,7 +202,7 @@ def run(self, p: processing.StableDiffusionProcessing, model, sampler, frames, g
p.extra_generation_params['CogVideoX'] = model
p.do_not_save_grid = True
if 'animatediff' not in p.ops:
- p.ops.append('cogvideox')
+ p.ops.append('video')
if override:
p.width = 720
p.height = 480
diff --git a/scripts/flux_tools.py b/scripts/flux_tools.py
new file mode 100644
index 000000000..50904eedb
--- /dev/null
+++ b/scripts/flux_tools.py
@@ -0,0 +1,149 @@
+# https://github.com/huggingface/diffusers/pull/9985
+
+import time
+import gradio as gr
+import diffusers
+from modules import scripts, processing, shared, devices, sd_models
+from installer import install
+
+
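+# module-level caches: the Redux prior pipeline and the Canny/Depth preprocessors are created on first use and released when a different tool is selected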
+# redux_pipe: diffusers.FluxPriorReduxPipeline = None
+redux_pipe = None
+processor_canny = None
+processor_depth = None
+title = 'Flux Tools'
+
+
+class Script(scripts.Script):
+ def title(self):
+ return f'{title}'
+
+ def show(self, is_img2img):
+ return is_img2img if shared.native else False
+
+ def ui(self, _is_img2img): # ui elements
+ with gr.Row():
+            gr.HTML('  Flux.1 Redux')
+ with gr.Row():
+ tool = gr.Dropdown(label='Tool', choices=['None', 'Redux', 'Fill', 'Canny', 'Depth'], value='None')
+ with gr.Row():
+ prompt = gr.Slider(label='Redux prompt strength', minimum=0, maximum=2, step=0.01, value=0, visible=False)
+ process = gr.Checkbox(label='Control preprocess input images', value=True, visible=False)
+ strength = gr.Checkbox(label='Control override denoise strength', value=True, visible=False)
+
+ def display(tool):
+ return [
+ gr.update(visible=tool in ['Redux']),
+ gr.update(visible=tool in ['Canny', 'Depth']),
+ gr.update(visible=tool in ['Canny', 'Depth']),
+ ]
+
+ tool.change(fn=display, inputs=[tool], outputs=[prompt, process, strength])
+ return [tool, prompt, strength, process]
+
+ def run(self, p: processing.StableDiffusionProcessing, tool: str = 'None', prompt: float = 1.0, strength: bool = True, process: bool = True): # pylint: disable=arguments-differ
+ global redux_pipe, processor_canny, processor_depth # pylint: disable=global-statement
+ if tool is None or tool == 'None':
+ return
+ supported_model_list = ['f1']
+ if shared.sd_model_type not in supported_model_list:
+ shared.log.warning(f'{title}: class={shared.sd_model.__class__.__name__} model={shared.sd_model_type} required={supported_model_list}')
+ return None
+ image = getattr(p, 'init_images', None)
+ if image is None or len(image) == 0:
+ shared.log.error(f'{title}: tool={tool} no init_images')
+ return None
+ else:
+ image = image[0] if isinstance(image, list) else image
+
+ shared.log.info(f'{title}: tool={tool} init')
+
+ t0 = time.time()
+ if tool == 'Redux':
+ # pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained("black-forest-labs/FLUX.1-Redux-dev", revision="refs/pr/8", torch_dtype=torch.bfloat16).to("cuda")
+ shared.log.debug(f'{title}: tool={tool} prompt={prompt}')
+ if redux_pipe is None:
+ redux_pipe = diffusers.FluxPriorReduxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Redux-dev",
+ revision="refs/pr/8",
+ torch_dtype=devices.dtype,
+ cache_dir=shared.opts.hfcache_dir
+ ).to(devices.device)
+ if prompt > 0:
+ shared.log.info(f'{title}: tool={tool} load text encoder')
+ redux_pipe.tokenizer, redux_pipe.tokenizer_2 = shared.sd_model.tokenizer, shared.sd_model.tokenizer_2
+ redux_pipe.text_encoder, redux_pipe.text_encoder_2 = shared.sd_model.text_encoder, shared.sd_model.text_encoder_2
+ sd_models.apply_balanced_offload(redux_pipe)
+ redux_output = redux_pipe(
+ image=image,
+ prompt=p.prompt if prompt > 0 else None,
+ prompt_embeds_scale=[prompt],
+ pooled_prompt_embeds_scale=[prompt],
+ )
+ if prompt > 0:
+ redux_pipe.tokenizer, redux_pipe.tokenizer_2 = None, None
+ redux_pipe.text_encoder, redux_pipe.text_encoder_2 = None, None
+ devices.torch_gc()
+ for k, v in redux_output.items():
+ p.task_args[k] = v
+ else:
+ if redux_pipe is not None:
+ shared.log.debug(f'{title}: tool=Redux unload')
+ redux_pipe = None
+
+ if tool == 'Fill':
+ # pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, revision="refs/pr/4").to("cuda")
+ if p.image_mask is None:
+ shared.log.error(f'{title}: tool={tool} no image_mask')
+ return None
+ if shared.sd_model.__class__.__name__ != 'FluxFillPipeline':
+ shared.opts.data["sd_model_checkpoint"] = "black-forest-labs/FLUX.1-Fill-dev"
+ sd_models.reload_model_weights(op='model', revision="refs/pr/4")
+ p.task_args['image'] = image
+ p.task_args['mask_image'] = p.image_mask
+
+ if tool == 'Canny':
+ # pipe = diffusers.FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16, revision="refs/pr/1").to("cuda")
+ install('controlnet-aux')
+ install('timm==0.9.16')
+ if shared.sd_model.__class__.__name__ != 'FluxControlPipeline' or 'Canny' not in shared.opts.sd_model_checkpoint:
+ shared.opts.data["sd_model_checkpoint"] = "black-forest-labs/FLUX.1-Canny-dev"
+ sd_models.reload_model_weights(op='model', revision="refs/pr/1")
+ if processor_canny is None:
+ from controlnet_aux import CannyDetector
+ processor_canny = CannyDetector()
+ if process:
+ control_image = processor_canny(image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)
+ p.task_args['control_image'] = control_image
+ else:
+ p.task_args['control_image'] = image
+ if strength:
+ p.task_args['strength'] = None
+ else:
+ if processor_canny is not None:
+ shared.log.debug(f'{title}: tool=Canny unload processor')
+ processor_canny = None
+
+ if tool == 'Depth':
+ # pipe = diffusers.FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16, revision="refs/pr/1").to("cuda")
+ install('git+https://github.com/huggingface/image_gen_aux.git', 'image_gen_aux')
+ if shared.sd_model.__class__.__name__ != 'FluxControlPipeline' or 'Depth' not in shared.opts.sd_model_checkpoint:
+ shared.opts.data["sd_model_checkpoint"] = "black-forest-labs/FLUX.1-Depth-dev"
+ sd_models.reload_model_weights(op='model', revision="refs/pr/1")
+ if processor_depth is None:
+ from image_gen_aux import DepthPreprocessor
+ processor_depth = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
+ if process:
+ control_image = processor_depth(image)[0].convert("RGB")
+ p.task_args['control_image'] = control_image
+ else:
+ p.task_args['control_image'] = image
+ if strength:
+ p.task_args['strength'] = None
+ else:
+ if processor_depth is not None:
+ shared.log.debug(f'{title}: tool=Depth unload processor')
+ processor_depth = None
+
+ shared.log.debug(f'{title}: tool={tool} ready time={time.time() - t0:.2f}')
+ devices.torch_gc()
diff --git a/scripts/freescale.py b/scripts/freescale.py
new file mode 100644
index 000000000..672ceea41
--- /dev/null
+++ b/scripts/freescale.py
@@ -0,0 +1,130 @@
+import gradio as gr
+from modules import scripts, processing, shared, sd_models
+
+
+registered = False
+
+
+class Script(scripts.Script):
+ def __init__(self):
+ super().__init__()
+ self.orig_pipe = None
+ self.orig_slice = None
+ self.orig_tile = None
+ self.is_img2img = False
+
+ def title(self):
+ return 'FreeScale: Tuning-Free Scale Fusion'
+
+ def show(self, is_img2img):
+ self.is_img2img = is_img2img
+ return shared.native
+
+ def ui(self, _is_img2img): # ui elements
+ with gr.Row():
+            gr.HTML('  FreeScale: Tuning-Free Scale Fusion')
+ with gr.Row():
+ cosine_scale = gr.Slider(minimum=0.1, maximum=5.0, value=2.0, label='Cosine scale')
+ override_sampler = gr.Checkbox(value=True, label='Override sampler')
+ with gr.Row(visible=self.is_img2img):
+ cosine_scale_bg = gr.Slider(minimum=0.1, maximum=5.0, value=1.0, label='Cosine Background')
+ dilate_tau = gr.Slider(minimum=1, maximum=100, value=35, label='Dilate tau')
+ with gr.Row():
+ s1_enable = gr.Checkbox(value=True, label='1st Stage', interactive=False)
+ s1_scale = gr.Slider(minimum=1, maximum=8.0, value=1.0, label='Scale')
+ s1_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='Restart step')
+ with gr.Row():
+ s2_enable = gr.Checkbox(value=True, label='2nd Stage')
+ s2_scale = gr.Slider(minimum=1, maximum=8.0, value=2.0, label='Scale')
+ s2_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='Restart step')
+ with gr.Row():
+ s3_enable = gr.Checkbox(value=False, label='3rd Stage')
+ s3_scale = gr.Slider(minimum=1, maximum=8.0, value=3.0, label='Scale')
+ s3_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='Restart step')
+ with gr.Row():
+ s4_enable = gr.Checkbox(value=False, label='4th Stage')
+ s4_scale = gr.Slider(minimum=1, maximum=8.0, value=4.0, label='Scale')
+ s4_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='Restart step')
+ return [cosine_scale, override_sampler, cosine_scale_bg, dilate_tau, s1_enable, s1_scale, s1_restart, s2_enable, s2_scale, s2_restart, s3_enable, s3_scale, s3_restart, s4_enable, s4_scale, s4_restart]
+
+ def run(self, p: processing.StableDiffusionProcessing, cosine_scale, override_sampler, cosine_scale_bg, dilate_tau, s1_enable, s1_scale, s1_restart, s2_enable, s2_scale, s2_restart, s3_enable, s3_scale, s3_restart, s4_enable, s4_scale, s4_restart): # pylint: disable=arguments-differ
+ supported_model_list = ['sdxl']
+ if shared.sd_model_type not in supported_model_list:
+ shared.log.warning(f'FreeScale: class={shared.sd_model.__class__.__name__} model={shared.sd_model_type} required={supported_model_list}')
+ return None
+
+ if self.is_img2img:
+ if p.init_images is None or len(p.init_images) == 0:
+ shared.log.warning('FreeScale: missing input image')
+ return None
+
+ from modules.freescale import StableDiffusionXLFreeScale, StableDiffusionXLFreeScaleImg2Img
+ self.orig_pipe = shared.sd_model
+ self.orig_slice = shared.opts.diffusers_vae_slicing
+ self.orig_tile = shared.opts.diffusers_vae_tiling
+
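+        # helper: target [width, height] for a given scale factor; falls back to the init image size when width/height are unset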
+ def scale(x):
+ if (p.width == 0 or p.height == 0) and p.init_images is not None:
+ p.width, p.height = p.init_images[0].width, p.init_images[0].height
+ resolution = [int(8 * p.width * x // 8), int(8 * p.height * x // 8)]
+ return resolution
+
+ scales = []
+ resolutions_list = []
+ restart_steps = []
+ if s1_enable:
+ scales.append(s1_scale)
+ resolutions_list.append(scale(s1_scale))
+ restart_steps.append(int(p.steps * s1_restart))
+ if s2_enable and s2_scale > s1_scale:
+ scales.append(s2_scale)
+ resolutions_list.append(scale(s2_scale))
+ restart_steps.append(int(p.steps * s2_restart))
+ if s3_enable and s3_scale > s2_scale:
+ scales.append(s3_scale)
+ resolutions_list.append(scale(s3_scale))
+ restart_steps.append(int(p.steps * s3_restart))
+ if s4_enable and s4_scale > s3_scale:
+ scales.append(s4_scale)
+ resolutions_list.append(scale(s4_scale))
+ restart_steps.append(int(p.steps * s4_restart))
+
+ p.task_args['resolutions_list'] = resolutions_list
+ p.task_args['cosine_scale'] = cosine_scale
+ p.task_args['restart_steps'] = [min(max(1, step), p.steps-1) for step in restart_steps]
+ if self.is_img2img:
+ p.task_args['cosine_scale_bg'] = cosine_scale_bg
+ p.task_args['dilate_tau'] = dilate_tau
+ p.task_args['img_path'] = p.init_images[0]
+ p.init_images = None
+ if override_sampler:
+ p.sampler_name = 'Euler a'
+
+ if p.width < 1024 or p.height < 1024:
+ shared.log.error(f'FreeScale: width={p.width} height={p.height} minimum=1024')
+ return None
+
+ if not self.is_img2img:
+ shared.sd_model = sd_models.switch_pipe(StableDiffusionXLFreeScale, shared.sd_model)
+ else:
+ shared.sd_model = sd_models.switch_pipe(StableDiffusionXLFreeScaleImg2Img, shared.sd_model)
+ shared.sd_model.enable_vae_slicing()
+ shared.sd_model.enable_vae_tiling()
+
+ shared.log.info(f'FreeScale: mode={"txt" if not self.is_img2img else "img"} cosine={cosine_scale} bg={cosine_scale_bg} tau={dilate_tau} scales={scales} resolutions={resolutions_list} steps={restart_steps} sampler={p.sampler_name}')
+ resolutions = ','.join([f'{x[0]}x{x[1]}' for x in resolutions_list])
+ steps = ','.join([str(x) for x in restart_steps])
+ p.extra_generation_params["FreeScale"] = f'cosine {cosine_scale} resolutions {resolutions} steps {steps}'
+
+ def after(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, *args): # pylint: disable=arguments-differ, unused-argument
+ if self.orig_pipe is None:
+ return processed
+ # restore pipeline
+ if shared.sd_model_type == "sdxl":
+ shared.sd_model = self.orig_pipe
+ self.orig_pipe = None
+ if not self.orig_slice:
+ shared.sd_model.disable_vae_slicing()
+ if not self.orig_tile:
+ shared.sd_model.disable_vae_tiling()
+ return processed
diff --git a/scripts/hunyuanvideo.py b/scripts/hunyuanvideo.py
new file mode 100644
index 000000000..b94c8b8f8
--- /dev/null
+++ b/scripts/hunyuanvideo.py
@@ -0,0 +1,111 @@
+import time
+import torch
+import gradio as gr
+import diffusers
+from modules import scripts, processing, shared, images, devices, sd_models, sd_checkpoint, model_quant
+
+
+repo_id = 'tencent/HunyuanVideo'
+"""
+prompt_template = { # default
+ "template": (
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
+ "1. The main content and theme of the video."
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the contents, including objects, people, and anything else."
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the contents."
+ "4. Background environment, light, style, atmosphere, and qualities."
+ "5. Camera angles, movements, and transitions used in the video."
+ "6. Thematic and aesthetic concepts associated with the scene, i.e. realistic, futuristic, fairy tale, etc<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+ ),
+ "crop_start": 95,
+}
+"""
+
+
+class Script(scripts.Script):
+ def title(self):
+ return 'Video: Hunyuan Video'
+
+ def show(self, is_img2img):
+ return not is_img2img if shared.native else False
+
+ # return signature is array of gradio components
+ def ui(self, _is_img2img):
+ def video_type_change(video_type):
+ return [
+ gr.update(visible=video_type != 'None'),
+ gr.update(visible=video_type == 'GIF' or video_type == 'PNG'),
+ gr.update(visible=video_type == 'MP4'),
+ gr.update(visible=video_type == 'MP4'),
+ ]
+
+ with gr.Row():
+            gr.HTML('  Hunyuan Video')
+ with gr.Row():
+ num_frames = gr.Slider(label='Frames', minimum=9, maximum=257, step=1, value=45)
+ with gr.Row():
+ video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None')
+ duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False)
+ with gr.Row():
+ gif_loop = gr.Checkbox(label='Loop', value=True, visible=False)
+ mp4_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False)
+ mp4_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False)
+ video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, gif_loop, mp4_pad, mp4_interpolate])
+ return [num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate]
+
+ def run(self, p: processing.StableDiffusionProcessing, num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate): # pylint: disable=arguments-differ, unused-argument
+ # set params
+ num_frames = int(num_frames)
+ p.width = 32 * int(p.width // 32)
+ p.height = 32 * int(p.height // 32)
+ p.task_args['output_type'] = 'pil'
+ p.task_args['generator'] = torch.manual_seed(p.seed)
+ p.task_args['num_frames'] = num_frames
+ # p.task_args['prompt_template'] = prompt_template
+ p.sampler_name = 'Default'
+ p.do_not_save_grid = True
+ p.ops.append('video')
+
+ # load model
+ cls = diffusers.HunyuanVideoPipeline
+ if shared.sd_model.__class__ != cls:
+ sd_models.unload_model_weights()
+ kwargs = {}
+ kwargs = model_quant.create_bnb_config(kwargs)
+ kwargs = model_quant.create_ao_config(kwargs)
+ transformer = diffusers.HunyuanVideoTransformer3DModel.from_pretrained(
+ repo_id,
+ subfolder="transformer",
+ torch_dtype=devices.dtype,
+ revision="refs/pr/18",
+ cache_dir = shared.opts.hfcache_dir,
+ **kwargs
+ )
+ shared.sd_model = cls.from_pretrained(
+ repo_id,
+ transformer=transformer,
+ revision="refs/pr/18",
+ cache_dir = shared.opts.hfcache_dir,
+ torch_dtype=devices.dtype,
+ **kwargs
+ )
+ shared.sd_model.scheduler._shift = 7.0 # pylint: disable=protected-access
+ sd_models.set_diffuser_options(shared.sd_model)
+ shared.sd_model.sd_checkpoint_info = sd_checkpoint.CheckpointInfo(repo_id)
+ shared.sd_model.sd_model_hash = None
+ shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
+ shared.sd_model.vae.enable_slicing()
+ shared.sd_model.vae.enable_tiling()
+ devices.torch_gc(force=True)
+ shared.log.debug(f'Video: cls={shared.sd_model.__class__.__name__} args={p.task_args}')
+
+ # run processing
+ t0 = time.time()
+ processed = processing.process_images(p)
+ t1 = time.time()
+ if processed is not None and len(processed.images) > 0:
+ shared.log.info(f'Video: frames={len(processed.images)} time={t1-t0:.2f}')
+ if video_type != 'None':
+ images.save_video(p, filename=None, images=processed.images, video_type=video_type, duration=duration, loop=gif_loop, pad=mp4_pad, interpolate=mp4_interpolate)
+ return processed
diff --git a/scripts/image2video.py b/scripts/image2video.py
index 876ed3193..ad6615f67 100644
--- a/scripts/image2video.py
+++ b/scripts/image2video.py
@@ -13,7 +13,7 @@
class Script(scripts.Script):
def title(self):
- return 'Video VGen Image-to-Video'
+ return 'Video: VGen Image-to-Video'
def show(self, is_img2img):
return is_img2img if shared.native else False
@@ -73,7 +73,7 @@ def run(self, p: processing.StableDiffusionProcessing, model_name, num_frames, v
model = [m for m in MODELS if m['name'] == model_name][0]
repo_id = model['url']
shared.log.debug(f'Image2Video: model={model_name} frames={num_frames}, video={video_type} duration={duration} loop={gif_loop} pad={mp4_pad} interpolate={mp4_interpolate}')
- p.ops.append('image2video')
+ p.ops.append('video')
p.do_not_save_grid = True
orig_pipeline = shared.sd_model
diff --git a/scripts/instantir.py b/scripts/instantir.py
index 5eb7d503a..6ab7733fe 100644
--- a/scripts/instantir.py
+++ b/scripts/instantir.py
@@ -80,7 +80,7 @@ def run(self, p: processing.StableDiffusionProcessing, *args): # pylint: disable
devices.torch_gc()
def after(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, *args): # pylint: disable=arguments-differ, unused-argument
- # TODO instantir is a mess to unload
+ # TODO instantir: a mess to unload
"""
if self.orig_pipe is None:
return processed
diff --git a/scripts/ipadapter.py b/scripts/ipadapter.py
index 60c70b9dc..5ca4ca578 100644
--- a/scripts/ipadapter.py
+++ b/scripts/ipadapter.py
@@ -57,6 +57,7 @@ def ui(self, _is_img2img):
mask_galleries = []
with gr.Row():
num_adapters = gr.Slider(label="Active IP adapters", minimum=1, maximum=MAX_ADAPTERS, step=1, value=1, scale=1)
+ unload_adapter = gr.Checkbox(label='Unload adapter', value=False, interactive=True)
for i in range(MAX_ADAPTERS):
with gr.Accordion(f'Adapter {i+1}', visible=i==0) as unit:
with gr.Row():
@@ -85,7 +86,7 @@ def ui(self, _is_img2img):
layers_label = gr.HTML('InstantStyle: advanced layer activation', visible=False)
layers = gr.Text(label='Layer scales', placeholder='{\n"down": {"block_2": [0.0, 1.0]},\n"up": {"block_0": [0.0, 1.0, 0.0]}\n}', rows=1, type='text', interactive=True, lines=5, visible=False, show_label=False)
layers_active.change(fn=self.display_advanced, inputs=[layers_active], outputs=[layers_label, layers])
- return [num_adapters] + adapters + scales + files + crops + starts + ends + masks + [layers_active] + [layers]
+ return [num_adapters] + [unload_adapter] + adapters + scales + files + crops + starts + ends + masks + [layers_active] + [layers]
def process(self, p: processing.StableDiffusionProcessing, *args): # pylint: disable=arguments-differ
if not shared.native:
@@ -94,6 +95,7 @@ def process(self, p: processing.StableDiffusionProcessing, *args): # pylint: dis
if len(args) == 0:
return
units = args.pop(0)
+ unload = args.pop(0)
if getattr(p, 'ip_adapter_names', []) == []:
p.ip_adapter_names = args[:MAX_ADAPTERS][:units]
if getattr(p, 'ip_adapter_scales', [0.0]) == [0.0]:
@@ -110,6 +112,7 @@ def process(self, p: processing.StableDiffusionProcessing, *args): # pylint: dis
p.ip_adapter_masks = args[MAX_ADAPTERS*6:MAX_ADAPTERS*7][:units]
p.ip_adapter_masks = [x for x in p.ip_adapter_masks if x]
layers_active, layers = args[MAX_ADAPTERS*7:MAX_ADAPTERS*8]
+ p.ip_adapter_unload = unload
if layers_active and len(layers) > 0:
try:
layers = json.loads(layers)
diff --git a/scripts/ltxvideo.py b/scripts/ltxvideo.py
new file mode 100644
index 000000000..007c4f4cc
--- /dev/null
+++ b/scripts/ltxvideo.py
@@ -0,0 +1,148 @@
+import os
+import time
+import torch
+import gradio as gr
+import diffusers
+import transformers
+from modules import scripts, processing, shared, images, devices, sd_models, sd_checkpoint, model_quant
+
+
+repos = {
+ '0.9.0': 'a-r-r-o-w/LTX-Video-diffusers',
+ '0.9.1': 'a-r-r-o-w/LTX-Video-0.9.1-diffusers',
+ 'custom': None,
+}
+
+
+def load_quants(kwargs, repo_id):
+ if len(shared.opts.bnb_quantization) > 0:
+ quant_args = {}
+ quant_args = model_quant.create_bnb_config(quant_args)
+ quant_args = model_quant.create_ao_config(quant_args)
+ if not quant_args:
+ return kwargs
+ model_quant.load_bnb(f'Load model: type=LTX quant={quant_args}')
+ if 'Model' in shared.opts.bnb_quantization and 'transformer' not in kwargs:
+ kwargs['transformer'] = diffusers.LTXVideoTransformer3DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=shared.opts.hfcache_dir, torch_dtype=devices.dtype, **quant_args)
+ shared.log.debug(f'Quantization: module=transformer type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
+ if 'Text Encoder' in shared.opts.bnb_quantization and 'text_encoder' not in kwargs:
+ kwargs['text_encoder'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder", cache_dir=shared.opts.hfcache_dir, torch_dtype=devices.dtype, **quant_args)
+ shared.log.debug(f'Quantization: module=t5 type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
+ return kwargs
+
+
+def hijack_decode(*args, **kwargs):
+ shared.log.debug('Video: decode')
+ shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
+ return shared.sd_model.vae.orig_decode(*args, **kwargs)
+
+
+class Script(scripts.Script):
+ def title(self):
+ return 'Video: LTX Video'
+
+ def show(self, is_img2img):
+ return shared.native
+
+ # return signature is array of gradio components
+ def ui(self, _is_img2img):
+ def video_type_change(video_type):
+ return [
+ gr.update(visible=video_type != 'None'),
+ gr.update(visible=video_type == 'GIF' or video_type == 'PNG'),
+ gr.update(visible=video_type == 'MP4'),
+ gr.update(visible=video_type == 'MP4'),
+ ]
+ def model_change(model):
+ return gr.update(visible=model == 'custom')
+
+ with gr.Row():
+ gr.HTML('  LTX Video')
+ with gr.Row():
+ model = gr.Dropdown(label='LTX Model', choices=list(repos), value='0.9.1')
+ decode = gr.Dropdown(label='Decode', choices=['diffusers', 'native'], value='diffusers', visible=False)
+ with gr.Row():
+ num_frames = gr.Slider(label='Frames', minimum=9, maximum=257, step=1, value=41)
+ sampler = gr.Checkbox(label='Override sampler', value=True)
+ with gr.Row():
+ model_custom = gr.Textbox(value='', label='Path to model file', visible=False)
+ with gr.Row():
+ video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None')
+ duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False)
+ with gr.Row():
+ gif_loop = gr.Checkbox(label='Loop', value=True, visible=False)
+ mp4_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False)
+ mp4_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False)
+ video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, gif_loop, mp4_pad, mp4_interpolate])
+ model.change(fn=model_change, inputs=[model], outputs=[model_custom])
+ return [model, model_custom, decode, sampler, num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate]
+
+ def run(self, p: processing.StableDiffusionProcessing, model, model_custom, decode, sampler, num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate): # pylint: disable=arguments-differ, unused-argument
+ # set params
+ image = getattr(p, 'init_images', None)
+ image = None if image is None or len(image) == 0 else image[0]
+ if (p.width == 0 or p.height == 0) and image is not None:
+ p.width = image.width
+ p.height = image.height
+ num_frames = 8 * int(num_frames // 8) + 1
+ p.width = 32 * int(p.width // 32)
+ p.height = 32 * int(p.height // 32)
+ processing.fix_seed(p)
+ if image:
+ image = images.resize_image(resize_mode=2, im=image, width=p.width, height=p.height, upscaler_name=None, output_type='pil')
+ p.task_args['image'] = image
+ p.task_args['output_type'] = 'latent' if decode == 'native' else 'pil'
+ p.task_args['generator'] = torch.Generator(devices.device).manual_seed(p.seed)
+ p.task_args['num_frames'] = num_frames
+ p.do_not_save_grid = True
+ if sampler:
+ p.sampler_name = 'Default'
+ p.ops.append('video')
+
+ # load model
+ cls = diffusers.LTXPipeline if image is None else diffusers.LTXImageToVideoPipeline
+ diffusers.LTXTransformer3DModel = diffusers.LTXVideoTransformer3DModel
+ diffusers.AutoencoderKLLTX = diffusers.AutoencoderKLLTXVideo
+ repo_id = repos[model]
+ if repo_id is None:
+ repo_id = model_custom
+ if shared.sd_model.__class__ != cls:
+ sd_models.unload_model_weights()
+ kwargs = {}
+ kwargs = model_quant.create_bnb_config(kwargs)
+ kwargs = model_quant.create_ao_config(kwargs)
+ if os.path.isfile(repo_id):
+ shared.sd_model = cls.from_single_file(
+ repo_id,
+ cache_dir = shared.opts.hfcache_dir,
+ torch_dtype=devices.dtype,
+ **kwargs
+ )
+ else:
+ kwargs = load_quants(kwargs, repo_id)
+ shared.sd_model = cls.from_pretrained(
+ repo_id,
+ cache_dir = shared.opts.hfcache_dir,
+ torch_dtype=devices.dtype,
+ **kwargs
+ )
+ sd_models.set_diffuser_options(shared.sd_model)
+ shared.sd_model.vae.orig_decode = shared.sd_model.vae.decode
+ shared.sd_model.vae.decode = hijack_decode
+ shared.sd_model.sd_checkpoint_info = sd_checkpoint.CheckpointInfo(repo_id)
+ shared.sd_model.sd_model_hash = None
+ shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
+ shared.sd_model.vae.enable_slicing()
+ shared.sd_model.vae.enable_tiling()
+ devices.torch_gc(force=True)
+ shared.log.debug(f'Video: cls={shared.sd_model.__class__.__name__} args={p.task_args}')
+
+ # run processing
+ t0 = time.time()
+ processed = processing.process_images(p)
+ t1 = time.time()
+ if processed is not None and len(processed.images) > 0:
+ shared.log.info(f'Video: frames={len(processed.images)} time={t1-t0:.2f}')
+ if video_type != 'None':
+ images.save_video(p, filename=None, images=processed.images, video_type=video_type, duration=duration, loop=gif_loop, pad=mp4_pad, interpolate=mp4_interpolate)
+ return processed
diff --git a/scripts/mochivideo.py b/scripts/mochivideo.py
new file mode 100644
index 000000000..f85616a5e
--- /dev/null
+++ b/scripts/mochivideo.py
@@ -0,0 +1,85 @@
+import time
+import torch
+import gradio as gr
+import diffusers
+from modules import scripts, processing, shared, images, devices, sd_models, sd_checkpoint, model_quant
+
+
+repo_id = 'genmo/mochi-1-preview'
+
+
+class Script(scripts.Script):
+ def title(self):
+ return 'Video: Mochi.1 Video'
+
+ def show(self, is_img2img):
+ return not is_img2img if shared.native else False
+
+ # return signature is array of gradio components
+ def ui(self, _is_img2img):
+ def video_type_change(video_type):
+ return [
+ gr.update(visible=video_type != 'None'),
+ gr.update(visible=video_type == 'GIF' or video_type == 'PNG'),
+ gr.update(visible=video_type == 'MP4'),
+ gr.update(visible=video_type == 'MP4'),
+ ]
+
+ with gr.Row():
+ gr.HTML('  Mochi.1 Video')
+ with gr.Row():
+ num_frames = gr.Slider(label='Frames', minimum=9, maximum=257, step=1, value=45)
+ with gr.Row():
+ video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None')
+ duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False)
+ with gr.Row():
+ gif_loop = gr.Checkbox(label='Loop', value=True, visible=False)
+ mp4_pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False)
+ mp4_interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False)
+ video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, gif_loop, mp4_pad, mp4_interpolate])
+ return [num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate]
+
+ def run(self, p: processing.StableDiffusionProcessing, num_frames, video_type, duration, gif_loop, mp4_pad, mp4_interpolate): # pylint: disable=arguments-differ, unused-argument
+ # set params
+ num_frames = int(num_frames // 8)
+ p.width = 32 * int(p.width // 32)
+ p.height = 32 * int(p.height // 32)
+ p.task_args['output_type'] = 'pil'
+ p.task_args['generator'] = torch.manual_seed(p.seed)
+ p.task_args['num_frames'] = num_frames
+ p.sampler_name = 'Default'
+ p.do_not_save_grid = True
+ p.ops.append('video')
+
+ # load model
+ cls = diffusers.MochiPipeline
+ if shared.sd_model.__class__ != cls:
+ sd_models.unload_model_weights()
+ kwargs = {}
+ kwargs = model_quant.create_bnb_config(kwargs)
+ kwargs = model_quant.create_ao_config(kwargs)
+ shared.sd_model = cls.from_pretrained(
+ repo_id,
+ cache_dir = shared.opts.hfcache_dir,
+ torch_dtype=devices.dtype,
+ **kwargs
+ )
+ shared.sd_model.scheduler._shift = 7.0 # pylint: disable=protected-access
+ sd_models.set_diffuser_options(shared.sd_model)
+ shared.sd_model.sd_checkpoint_info = sd_checkpoint.CheckpointInfo(repo_id)
+ shared.sd_model.sd_model_hash = None
+ shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
+ shared.sd_model.vae.enable_slicing()
+ shared.sd_model.vae.enable_tiling()
+ devices.torch_gc(force=True)
+ shared.log.debug(f'Video: cls={shared.sd_model.__class__.__name__} args={p.task_args}')
+
+ # run processing
+ t0 = time.time()
+ processed = processing.process_images(p)
+ t1 = time.time()
+ if processed is not None and len(processed.images) > 0:
+ shared.log.info(f'Video: frames={len(processed.images)} time={t1-t0:.2f}')
+ if video_type != 'None':
+ images.save_video(p, filename=None, images=processed.images, video_type=video_type, duration=duration, loop=gif_loop, pad=mp4_pad, interpolate=mp4_interpolate)
+ return processed
diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py
index 3e5fd451d..8ed677736 100644
--- a/scripts/postprocessing_codeformer.py
+++ b/scripts/postprocessing_codeformer.py
@@ -10,9 +10,10 @@ class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing
order = 3000
def ui(self):
- with gr.Row():
- codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Strength", value=0.0, elem_id="extras_codeformer_visibility")
- codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Weight", value=0.2, elem_id="extras_codeformer_weight")
+ with gr.Accordion('Restore faces: CodeFormer', open = False):
+ with gr.Row():
+ codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Strength", value=0.0, elem_id="extras_codeformer_visibility")
+ codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Weight", value=0.2, elem_id="extras_codeformer_weight")
return { "codeformer_visibility": codeformer_visibility, "codeformer_weight": codeformer_weight }
def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight): # pylint: disable=arguments-differ
diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py
index a69f97c9e..1e17c2b16 100644
--- a/scripts/postprocessing_gfpgan.py
+++ b/scripts/postprocessing_gfpgan.py
@@ -9,8 +9,9 @@ class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):
order = 2000
def ui(self):
- with gr.Row():
- gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Strength", value=0, elem_id="extras_gfpgan_visibility")
+ with gr.Accordion('Restore faces: GFPGan', open = False):
+ with gr.Row():
+ gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Strength", value=0, elem_id="extras_gfpgan_visibility")
return { "gfpgan_visibility": gfpgan_visibility }
def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility): # pylint: disable=arguments-differ
diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py
index 0e07bc847..104a0fb37 100644
--- a/scripts/postprocessing_upscale.py
+++ b/scripts/postprocessing_upscale.py
@@ -10,43 +10,43 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
order = 1000
def ui(self):
- selected_tab = gr.State(value=0) # pylint: disable=abstract-class-instantiated
-
- with gr.Column():
- with gr.Row(elem_id="extras_upscale"):
- with gr.Tabs(elem_id="extras_resize_mode"):
- with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
- upscaling_resize = gr.Slider(minimum=0.1, maximum=8.0, step=0.05, label="Resize", value=2.0, elem_id="extras_upscaling_resize")
-
- with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
- with gr.Row():
- with gr.Row(elem_id="upscaling_column_size"):
- upscaling_resize_w = gr.Slider(minimum=64, maximum=4096, step=8, label="Width", value=1024, elem_id="extras_upscaling_resize_w")
- upscaling_resize_h = gr.Slider(minimum=64, maximum=4096, step=8, label="Height", value=1024, elem_id="extras_upscaling_resize_h")
- upscaling_res_switch_btn = ToolButton(value=symbols.switch, elem_id="upscaling_res_switch_btn")
- upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
-
- with gr.Row():
- extras_upscaler_1 = gr.Dropdown(label='Upscaler', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
-
- with gr.Row():
- extras_upscaler_2 = gr.Dropdown(label='Refine Upscaler', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
- extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
-
- upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False)
- tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
- tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
-
- return {
- "upscale_mode": selected_tab,
- "upscale_by": upscaling_resize,
- "upscale_to_width": upscaling_resize_w,
- "upscale_to_height": upscaling_resize_h,
- "upscale_crop": upscaling_crop,
- "upscaler_1_name": extras_upscaler_1,
- "upscaler_2_name": extras_upscaler_2,
- "upscaler_2_visibility": extras_upscaler_2_visibility,
- }
+ with gr.Accordion('Postprocess Upscale', open = False):
+ selected_tab = gr.State(value=0) # pylint: disable=abstract-class-instantiated
+ with gr.Column():
+ with gr.Row(elem_id="extras_upscale"):
+ with gr.Tabs(elem_id="extras_resize_mode"):
+ with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
+ upscaling_resize = gr.Slider(minimum=0.1, maximum=8.0, step=0.05, label="Resize", value=2.0, elem_id="extras_upscaling_resize")
+
+ with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
+ with gr.Row():
+ with gr.Row(elem_id="upscaling_column_size"):
+ upscaling_resize_w = gr.Slider(minimum=64, maximum=4096, step=8, label="Width", value=1024, elem_id="extras_upscaling_resize_w")
+ upscaling_resize_h = gr.Slider(minimum=64, maximum=4096, step=8, label="Height", value=1024, elem_id="extras_upscaling_resize_h")
+ upscaling_res_switch_btn = ToolButton(value=symbols.switch, elem_id="upscaling_res_switch_btn")
+ upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
+
+ with gr.Row():
+ extras_upscaler_1 = gr.Dropdown(label='Upscaler', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+
+ with gr.Row():
+ extras_upscaler_2 = gr.Dropdown(label='Refine Upscaler', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+ extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
+
+ upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False)
+ tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
+ tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
+
+ return {
+ "upscale_mode": selected_tab,
+ "upscale_by": upscaling_resize,
+ "upscale_to_width": upscaling_resize_w,
+ "upscale_to_height": upscaling_resize_h,
+ "upscale_crop": upscaling_crop,
+ "upscaler_1_name": extras_upscaler_1,
+ "upscaler_2_name": extras_upscaler_2,
+ "upscaler_2_visibility": extras_upscaler_2_visibility,
+ }
def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop):
if upscale_mode == 1:
diff --git a/scripts/postprocessing_video.py b/scripts/postprocessing_video.py
index 8c266639c..75d375572 100644
--- a/scripts/postprocessing_video.py
+++ b/scripts/postprocessing_video.py
@@ -7,40 +7,41 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
name = "Video"
def ui(self):
- def video_type_change(video_type):
- return [
- gr.update(visible=video_type != 'None'),
- gr.update(visible=video_type == 'GIF' or video_type == 'PNG'),
- gr.update(visible=video_type == 'MP4'),
- gr.update(visible=video_type == 'MP4'),
- gr.update(visible=video_type == 'MP4'),
- gr.update(visible=video_type == 'MP4'),
- ]
+ with gr.Accordion('Create video', open = False):
+ def video_type_change(video_type):
+ return [
+ gr.update(visible=video_type != 'None'),
+ gr.update(visible=video_type == 'GIF' or video_type == 'PNG'),
+ gr.update(visible=video_type == 'MP4'),
+ gr.update(visible=video_type == 'MP4'),
+ gr.update(visible=video_type == 'MP4'),
+ gr.update(visible=video_type == 'MP4'),
+ ]
- with gr.Row():
- gr.HTML("  Video
")
- with gr.Row():
- video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None', elem_id="extras_video_type")
- duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False, elem_id="extras_video_duration")
- with gr.Row():
- loop = gr.Checkbox(label='Loop', value=True, visible=False, elem_id="extras_video_loop")
- pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False, elem_id="extras_video_pad")
- interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False, elem_id="extras_video_interpolate")
- scale = gr.Slider(label='Rescale', minimum=0.5, maximum=2, step=0.05, value=1, visible=False, elem_id="extras_video_scale")
- change = gr.Slider(label='Frame change sensitivity', minimum=0, maximum=1, step=0.05, value=0.3, visible=False, elem_id="extras_video_change")
- with gr.Row():
- filename = gr.Textbox(label='Filename', placeholder='enter filename', lines=1, elem_id="extras_video_filename")
- video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, loop, pad, interpolate, scale, change])
- return {
- "filename": filename,
- "video_type": video_type,
- "duration": duration,
- "loop": loop,
- "pad": pad,
- "interpolate": interpolate,
- "scale": scale,
- "change": change,
- }
+ with gr.Row():
+ gr.HTML("  Video
")
+ with gr.Row():
+ video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None', elem_id="extras_video_type")
+ duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False, elem_id="extras_video_duration")
+ with gr.Row():
+ loop = gr.Checkbox(label='Loop', value=True, visible=False, elem_id="extras_video_loop")
+ pad = gr.Slider(label='Pad frames', minimum=0, maximum=24, step=1, value=1, visible=False, elem_id="extras_video_pad")
+ interpolate = gr.Slider(label='Interpolate frames', minimum=0, maximum=24, step=1, value=0, visible=False, elem_id="extras_video_interpolate")
+ scale = gr.Slider(label='Rescale', minimum=0.5, maximum=2, step=0.05, value=1, visible=False, elem_id="extras_video_scale")
+ change = gr.Slider(label='Frame change sensitivity', minimum=0, maximum=1, step=0.05, value=0.3, visible=False, elem_id="extras_video_change")
+ with gr.Row():
+ filename = gr.Textbox(label='Filename', placeholder='enter filename', lines=1, elem_id="extras_video_filename")
+ video_type.change(fn=video_type_change, inputs=[video_type], outputs=[duration, loop, pad, interpolate, scale, change])
+ return {
+ "filename": filename,
+ "video_type": video_type,
+ "duration": duration,
+ "loop": loop,
+ "pad": pad,
+ "interpolate": interpolate,
+ "scale": scale,
+ "change": change,
+ }
def postprocess(self, images, filename, video_type, duration, loop, pad, interpolate, scale, change): # pylint: disable=arguments-differ
filename = filename.strip() if filename is not None else ''
diff --git a/scripts/pulid_ext.py b/scripts/pulid_ext.py
index 676fa79f3..ee08e348b 100644
--- a/scripts/pulid_ext.py
+++ b/scripts/pulid_ext.py
@@ -164,11 +164,13 @@ def run(
p.batch_size = 1
sdp = shared.opts.cross_attention_optimization == "Scaled-Dot-Product"
+ sampler_fn = getattr(self.pulid.sampling, f'sample_{sampler}', None)
strength = getattr(p, 'pulid_strength', strength)
zero = getattr(p, 'pulid_zero', zero)
ortho = getattr(p, 'pulid_ortho', ortho)
sampler = getattr(p, 'pulid_sampler', sampler)
- sampler_fn = getattr(self.pulid.sampling, f'sample_{sampler}', None)
+ restore = getattr(p, 'pulid_restore', restore)
+ p.pulid_restore = restore
if sampler_fn is None:
sampler_fn = self.pulid.sampling.sample_dpmpp_2m_sde
@@ -199,7 +201,7 @@ def run(
return None
shared.sd_model.sampler = sampler_fn
- shared.log.info(f'PuLID: class={shared.sd_model.__class__.__name__} version="{version}" sdp={sdp} strength={strength} zero={zero} ortho={ortho} sampler={sampler_fn} images={[i.shape for i in images]} offload={offload}')
+ shared.log.info(f'PuLID: class={shared.sd_model.__class__.__name__} version="{version}" sdp={sdp} strength={strength} zero={zero} ortho={ortho} sampler={sampler_fn} images={[i.shape for i in images]} offload={offload} restore={restore}')
self.pulid.attention.NUM_ZERO = zero
self.pulid.attention.ORTHO = ortho == 'v1'
self.pulid.attention.ORTHO_v2 = ortho == 'v2'
diff --git a/scripts/regional_prompting.py b/scripts/regional_prompting.py
index 08b84dd94..3d452b0c5 100644
--- a/scripts/regional_prompting.py
+++ b/scripts/regional_prompting.py
@@ -9,7 +9,7 @@
def hijack_register_modules(self, **kwargs):
for name, module in kwargs.items():
register_dict = None
- if module is None or isinstance(module, (tuple, list)) and module[0] is None:
+ if module is None or (isinstance(module, (tuple, list)) and module[0] is None):
register_dict = {name: (None, None)}
elif isinstance(module, bool):
pass
@@ -82,6 +82,7 @@ def run(self, p: processing.StableDiffusionProcessing, mode, grid, power, thresh
}
# run pipeline
shared.log.debug(f'Regional: args={p.task_args}')
+ p.task_args['prompt'] = p.prompt
processed: processing.Processed = processing.process_images(p) # runs processing using main loop
# restore pipeline and params
diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py
index 9c5a72204..7ac31b603 100644
--- a/scripts/sd_upscale.py
+++ b/scripts/sd_upscale.py
@@ -48,7 +48,7 @@ def run(self, p, _, overlap, upscaler_index, scale_factor): # pylint: disable=ar
else:
img = init_img
devices.torch_gc()
- grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap)
+ grid = images.split_grid(img, tile_w=init_img.width, tile_h=init_img.height, overlap=overlap)
batch_size = p.batch_size
upscale_count = p.n_iter
p.n_iter = 1
@@ -61,7 +61,7 @@ def run(self, p, _, overlap, upscaler_index, scale_factor): # pylint: disable=ar
batch_count = math.ceil(len(work) / batch_size)
state.job_count = batch_count * upscale_count
- log.info(f"SD upscale: images={len(work)} tile={len(grid.tiles[0][2])}x{len(grid.tiles)} batches={state.job_count}")
+ log.info(f"SD upscale: images={len(work)} tiles={len(grid.tiles)} batches={state.job_count}")
result_images = []
for n in range(upscale_count):
@@ -91,4 +91,5 @@ def run(self, p, _, overlap, upscaler_index, scale_factor): # pylint: disable=ar
images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p)
processed = Processed(p, result_images, seed, initial_info)
+ log.info(f"SD upscale: images={result_images}")
return processed
diff --git a/scripts/stablevideodiffusion.py b/scripts/stablevideodiffusion.py
index cbf2ce003..c1283e1b6 100644
--- a/scripts/stablevideodiffusion.py
+++ b/scripts/stablevideodiffusion.py
@@ -16,7 +16,7 @@
class Script(scripts.Script):
def title(self):
- return 'Video: SVD'
+ return 'Video: Stable Video Diffusion'
def show(self, is_img2img):
return is_img2img if shared.native else False
@@ -75,13 +75,17 @@ def run(self, p: processing.StableDiffusionProcessing, model, num_frames, overri
if model_name != model_loaded or c != 'StableVideoDiffusionPipeline':
shared.opts.sd_model_checkpoint = model_path
sd_models.reload_model_weights()
+ shared.sd_model = shared.sd_model.to(torch.float32) # TODO svd: force fp32 to avoid dtype mismatch
# set params
if override_resolution:
p.width = 1024
p.height = 576
image = images.resize_image(resize_mode=2, im=image, width=p.width, height=p.height, upscaler_name=None, output_type='pil')
- p.ops.append('svd')
+ else:
+ p.width = image.width
+ p.height = image.height
+ p.ops.append('video')
p.do_not_save_grid = True
p.init_images = [image]
p.sampler_name = 'Default' # svd does not support non-default sampler
diff --git a/scripts/style_aligned.py b/scripts/style_aligned.py
new file mode 100644
index 000000000..25feb49bc
--- /dev/null
+++ b/scripts/style_aligned.py
@@ -0,0 +1,117 @@
+import gradio as gr
+import torch
+import numpy as np
+import diffusers
+from modules import scripts, processing, shared, devices
+
+
+handler = None
+zts = None
+supported_model_list = ['sdxl']
+orig_prompt_attention = None
+
+
+class Script(scripts.Script):
+ def title(self):
+ return 'Style Aligned Image Generation'
+
+ def show(self, is_img2img):
+ return shared.native
+
+ def reset(self):
+ global handler, zts # pylint: disable=global-statement
+ handler = None
+ zts = None
+ shared.log.info('SA: image upload')
+
+ def preset(self, preset):
+ if preset == 'text':
+ return [['attention', 'adain_queries', 'adain_keys'], 1.0, 0, 0.0]
+ if preset == 'image':
+ return [['group_norm', 'layer_norm', 'attention', 'adain_queries', 'adain_keys'], 1.0, 2, 0.0]
+ if preset == 'all':
+ return [['group_norm', 'layer_norm', 'attention', 'adain_queries', 'adain_keys', 'adain_values', 'full_attention_share'], 1.0, 1, 0.5]
+
+ def ui(self, _is_img2img): # ui elements
+ with gr.Row():
+ gr.HTML('  Style Aligned Image Generation')
+ with gr.Row():
+ preset = gr.Dropdown(label="Preset", choices=['text', 'image', 'all'], value='text')
+ scheduler = gr.Checkbox(label="Override scheduler", value=False)
+ with gr.Row():
+ shared_opts = gr.Dropdown(label="Shared options",
+ multiselect=True,
+ choices=['group_norm', 'layer_norm', 'attention', 'adain_queries', 'adain_keys', 'adain_values', 'full_attention_share'],
+ value=['attention', 'adain_queries', 'adain_keys'],
+ )
+ with gr.Row():
+ shared_score_scale = gr.Slider(label="Scale", minimum=0.0, maximum=2.0, step=0.01, value=1.0)
+ shared_score_shift = gr.Slider(label="Shift", minimum=0, maximum=10, step=1, value=0)
+ only_self_level = gr.Slider(label="Level", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
+ with gr.Row():
+ prompt = gr.Textbox(lines=1, label='Optional image description', placeholder='use the style from the image')
+ with gr.Row():
+ image = gr.Image(label='Optional image', source='upload', type='pil')
+
+ image.change(self.reset)
+ preset.change(self.preset, inputs=[preset], outputs=[shared_opts, shared_score_scale, shared_score_shift, only_self_level])
+
+ return [image, prompt, scheduler, shared_opts, shared_score_scale, shared_score_shift, only_self_level]
+
+ def run(self, p: processing.StableDiffusionProcessing, image, prompt, scheduler, shared_opts, shared_score_scale, shared_score_shift, only_self_level): # pylint: disable=arguments-differ
+ global handler, zts, orig_prompt_attention # pylint: disable=global-statement
+ if shared.sd_model_type not in supported_model_list:
+ shared.log.warning(f'SA: class={shared.sd_model.__class__.__name__} model={shared.sd_model_type} required={supported_model_list}')
+ return None
+
+ from modules.style_aligned import sa_handler, inversion
+
+ handler = sa_handler.Handler(shared.sd_model)
+ sa_args = sa_handler.StyleAlignedArgs(
+ share_group_norm='group_norm' in shared_opts,
+ share_layer_norm='layer_norm' in shared_opts,
+ share_attention='attention' in shared_opts,
+ adain_queries='adain_queries' in shared_opts,
+ adain_keys='adain_keys' in shared_opts,
+ adain_values='adain_values' in shared_opts,
+ full_attention_share='full_attention_share' in shared_opts,
+ shared_score_scale=float(shared_score_scale),
+ shared_score_shift=np.log(shared_score_shift) if shared_score_shift > 0 else 0,
+ only_self_level=1 if only_self_level else 0,
+ )
+ handler.register(sa_args)
+
+ if scheduler:
+ shared.sd_model.scheduler = diffusers.DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+ p.sampler_name = 'None'
+
+ if image is not None and zts is None:
+ shared.log.info(f'SA: inversion image={image} prompt="{prompt}"')
+ image = image.resize((1024, 1024))
+ x0 = np.array(image).astype(np.float32) / 255.0
+ shared.sd_model.scheduler = diffusers.DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+ zts = inversion.ddim_inversion(shared.sd_model, x0, prompt, num_inference_steps=50, guidance_scale=2)
+
+ p.prompt = p.prompt.splitlines()
+ p.batch_size = len(p.prompt)
+ orig_prompt_attention = shared.opts.prompt_attention
+ shared.opts.data['prompt_attention'] = 'fixed' # otherwise need to deal with class_tokens_mask
+
+ if zts is not None:
+ processing.fix_seed(p)
+ zT, inversion_callback = inversion.make_inversion_callback(zts, offset=0)
+ generator = torch.Generator(device='cpu')
+ generator.manual_seed(p.seed)
+ latents = torch.randn(p.batch_size, 4, 128, 128, device='cpu', generator=generator, dtype=devices.dtype,).to(devices.device)
+ latents[0] = zT
+ p.task_args['latents'] = latents
+ p.task_args['callback_on_step_end'] = inversion_callback
+
+ shared.log.info(f'SA: batch={p.batch_size} type={"image" if zts is not None else "text"} config={sa_args.__dict__}')
+
+ def after(self, p: processing.StableDiffusionProcessing, *args): # pylint: disable=unused-argument
+ global handler # pylint: disable=global-statement
+ if handler is not None:
+ handler.remove()
+ handler = None
+ shared.opts.data['prompt_attention'] = orig_prompt_attention
diff --git a/scripts/text2video.py b/scripts/text2video.py
index dc4c44cac..c7b3d1c05 100644
--- a/scripts/text2video.py
+++ b/scripts/text2video.py
@@ -87,7 +87,7 @@ def run(self, p: processing.StableDiffusionProcessing, model_name, use_default,
shared.opts.sd_model_checkpoint = checkpoint.name
sd_models.reload_model_weights(op='model')
- p.ops.append('text2video')
+ p.ops.append('video')
p.do_not_save_grid = True
if use_default:
p.task_args['num_frames'] = model['params'][0]
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 60e608c76..bb067ea21 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -12,6 +12,7 @@
from scripts.xyz_grid_shared import str_permutations, list_to_csv_string, re_range # pylint: disable=no-name-in-module
from scripts.xyz_grid_classes import axis_options, AxisOption, SharedSettingsStackHelper # pylint: disable=no-name-in-module
from scripts.xyz_grid_draw import draw_xyz_grid # pylint: disable=no-name-in-module
+from scripts.xyz_grid_shared import apply_field, apply_task_args, apply_setting, apply_prompt, apply_order, apply_sampler, apply_hr_sampler_name, confirm_samplers, apply_checkpoint, apply_refiner, apply_unet, apply_dict, apply_clip_skip, apply_vae, list_lora, apply_lora, apply_lora_strength, apply_te, apply_styles, apply_upscaler, apply_context, apply_detailer, apply_override, apply_processing, apply_options, apply_seed, format_value_add_label, format_value, format_value_join_list, do_nothing, format_nothing # pylint: disable=no-name-in-module, unused-import
from modules import shared, errors, scripts, images, processing
from modules.ui_components import ToolButton
import modules.ui_symbols as symbols
diff --git a/scripts/xyz_grid_classes.py b/scripts/xyz_grid_classes.py
index b80b9f13c..cc70d68f8 100644
--- a/scripts/xyz_grid_classes.py
+++ b/scripts/xyz_grid_classes.py
@@ -1,4 +1,4 @@
-from scripts.xyz_grid_shared import apply_field, apply_task_args, apply_setting, apply_prompt, apply_order, apply_sampler, apply_hr_sampler_name, confirm_samplers, apply_checkpoint, apply_refiner, apply_unet, apply_dict, apply_clip_skip, apply_vae, list_lora, apply_lora, apply_te, apply_styles, apply_upscaler, apply_context, apply_detailer, apply_override, apply_processing, apply_options, apply_seed, format_value_add_label, format_value, format_value_join_list, do_nothing, format_nothing, str_permutations # pylint: disable=no-name-in-module, unused-import
+from scripts.xyz_grid_shared import apply_field, apply_task_args, apply_setting, apply_prompt, apply_order, apply_sampler, apply_hr_sampler_name, confirm_samplers, apply_checkpoint, apply_refiner, apply_unet, apply_dict, apply_clip_skip, apply_vae, list_lora, apply_lora, apply_lora_strength, apply_te, apply_styles, apply_upscaler, apply_context, apply_detailer, apply_override, apply_processing, apply_options, apply_seed, format_value_add_label, format_value, format_value_join_list, do_nothing, format_nothing, str_permutations # pylint: disable=no-name-in-module, unused-import
from modules import shared, shared_items, sd_samplers, ipadapter, sd_models, sd_vae, sd_unet
@@ -97,7 +97,7 @@ def __exit__(self, exc_type, exc_value, tb):
AxisOption("[Prompt] Prompt order", str_permutations, apply_order, fmt=format_value_join_list),
AxisOption("[Prompt] Prompt parser", str, apply_setting("prompt_attention"), choices=lambda: ["native", "compel", "xhinker", "a1111", "fixed"]),
AxisOption("[Network] LoRA", str, apply_lora, cost=0.5, choices=list_lora),
- AxisOption("[Network] LoRA strength", float, apply_setting('extra_networks_default_multiplier')),
+ AxisOption("[Network] LoRA strength", float, apply_lora_strength, cost=0.6),
AxisOption("[Network] Styles", str, apply_styles, choices=lambda: [s.name for s in shared.prompt_styles.styles.values()]),
AxisOption("[Param] Width", int, apply_field("width")),
AxisOption("[Param] Height", int, apply_field("height")),
diff --git a/scripts/xyz_grid_on.py b/scripts/xyz_grid_on.py
index 202a2cfc4..aa0897442 100644
--- a/scripts/xyz_grid_on.py
+++ b/scripts/xyz_grid_on.py
@@ -413,6 +413,7 @@ def cell(x, y, z, ix, iy, iz):
p.do_not_save_grid = True
p.do_not_save_samples = True
+ p.disable_extra_networks = True
active = False
cache = processed
return processed
diff --git a/scripts/xyz_grid_shared.py b/scripts/xyz_grid_shared.py
index d3ee0a864..f9bf26c67 100644
--- a/scripts/xyz_grid_shared.py
+++ b/scripts/xyz_grid_shared.py
@@ -63,28 +63,15 @@ def apply_seed(p, x, xs):
def apply_prompt(p, x, xs):
- if not hasattr(p, 'orig_prompt'):
- p.orig_prompt = p.prompt
- p.orig_negative = p.negative_prompt
- if xs[0] not in p.orig_prompt and xs[0] not in p.orig_negative:
- shared.log.warning(f'XYZ grid: prompt S/R string="{xs[0]}" not found')
- else:
- p.prompt = p.orig_prompt.replace(xs[0], x)
- p.negative_prompt = p.orig_negative.replace(xs[0], x)
- p.all_prompts = None
- p.all_negative_prompts = None
- """
- if p.all_prompts is not None:
- for i in range(len(p.all_prompts)):
- for j in range(len(xs)):
- p.all_prompts[i] = p.all_prompts[i].replace(xs[j], x)
- p.negative_prompt = p.negative_prompt.replace(xs[0], x)
- if p.all_negative_prompts is not None:
- for i in range(len(p.all_negative_prompts)):
- for j in range(len(xs)):
- p.all_negative_prompts[i] = p.all_negative_prompts[i].replace(xs[j], x)
- """
- shared.log.debug(f'XYZ grid apply prompt: "{xs[0]}"="{x}"')
+ for s in xs:
+ if s in p.prompt:
+ shared.log.debug(f'XYZ grid apply prompt: "{s}"="{x}"')
+ p.prompt = p.prompt.replace(s, x)
+ if s in p.negative_prompt:
+ shared.log.debug(f'XYZ grid apply negative: "{s}"="{x}"')
+ p.negative_prompt = p.negative_prompt.replace(s, x)
+ p.all_prompts = None
+ p.all_negative_prompts = None
def apply_order(p, x, xs):
@@ -205,19 +192,28 @@ def apply_vae(p, x, xs):
def list_lora():
import sys
- lora = [v for k, v in sys.modules.items() if k == 'networks'][0]
+ lora = [v for k, v in sys.modules.items() if k == 'networks' or k == 'modules.lora.networks'][0]
loras = [v.fullname for v in lora.available_networks.values()]
return ['None'] + loras
def apply_lora(p, x, xs):
+ p.all_prompts = None
+ p.all_negative_prompts = None
if x == 'None':
return
x = os.path.basename(x)
p.prompt = p.prompt + f" <lora:{x}:{shared.opts.extra_networks_default_multiplier}>"
+ shared.log.debug(f'XYZ grid apply LoRA: "{x}"')
+
+
+def apply_lora_strength(p, x, xs):
+ shared.log.debug(f'XYZ grid apply LoRA strength: "{x}"')
+ p.prompt = p.prompt.replace(':1.0>', '>')
+ p.prompt = p.prompt.replace(f':{shared.opts.extra_networks_default_multiplier}>', '>')
p.all_prompts = None
p.all_negative_prompts = None
- shared.log.debug(f'XYZ grid apply LoRA: "{x}"')
+ shared.opts.data['extra_networks_default_multiplier'] = x
def apply_te(p, x, xs):
diff --git a/webui.py b/webui.py
index 9684ef7c8..4eb6e89ce 100644
--- a/webui.py
+++ b/webui.py
@@ -8,6 +8,7 @@
import importlib
import contextlib
from threading import Thread
+import modules.hashes
import modules.loader
import torch # pylint: disable=wrong-import-order
from modules import timer, errors, paths # pylint: disable=unused-import
@@ -18,6 +19,7 @@
from modules.paths import create_paths
from modules.call_queue import queue_lock, wrap_queued_call, wrap_gradio_gpu_call # pylint: disable=unused-import
import modules.devices
+import modules.sd_checkpoint
import modules.sd_samplers
import modules.lowvram
import modules.scripts
@@ -77,6 +79,9 @@ def check_rollback_vae():
def initialize():
log.debug('Initializing')
+
+ modules.sd_checkpoint.init_metadata()
+ modules.hashes.init_cache()
check_rollback_vae()
modules.sd_samplers.list_samplers()
@@ -89,15 +94,20 @@ def initialize():
timer.startup.record("unet")
modules.model_te.refresh_te_list()
- timer.startup.record("unet")
-
- extensions.list_extensions()
- timer.startup.record("extensions")
+ timer.startup.record("te")
modelloader.cleanup_models()
modules.sd_models.setup_model()
timer.startup.record("models")
+ if shared.native:
+ import modules.lora.networks as lora_networks
+ lora_networks.list_available_networks()
+ timer.startup.record("lora")
+
+ shared.prompt_styles.reload()
+ timer.startup.record("styles")
+
import modules.postprocess.codeformer_model as codeformer
codeformer.setup_model(shared.opts.codeformer_models_path)
sys.modules["modules.codeformer_model"] = codeformer
@@ -107,6 +117,9 @@ def initialize():
yolo.initialize()
timer.startup.record("detailer")
+ extensions.list_extensions()
+ timer.startup.record("extensions")
+
log.info('Load extensions')
t_timer, t_total = modules.scripts.load_scripts()
timer.startup.record("extensions")
@@ -116,8 +129,9 @@ def initialize():
modelloader.load_upscalers()
timer.startup.record("upscalers")
- shared.reload_hypernetworks()
- shared.prompt_styles.reload()
+ if shared.opts.hypernetwork_enabled:
+ shared.reload_hypernetworks()
+ timer.startup.record("hypernetworks")
ui_extra_networks.initialize()
ui_extra_networks.register_pages()
diff --git a/wiki b/wiki
index 30f3265bb..56ba782f7 160000
--- a/wiki
+++ b/wiki
@@ -1 +1 @@
-Subproject commit 30f3265bb06ac738e4467f58be4df3fc4b49c08b
+Subproject commit 56ba782f744bb8f6928f6c365d6ffc547d339548