From cca81bf296c07e6864e4dc0085a2601dcc91584b Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 18 Jan 2025 12:35:31 +0100 Subject: [PATCH] update --- docs/source/en/using-diffusers/consisid.md | 3 +- src/diffusers/loaders/__init__.py | 2 - src/diffusers/loaders/lora_pipeline.py | 307 ------------------ .../transformers/consisid_transformer_3d.py | 231 ++++--------- .../pipelines/consisid/pipeline_consisid.py | 9 +- 5 files changed, 62 insertions(+), 490 deletions(-) diff --git a/docs/source/en/using-diffusers/consisid.md b/docs/source/en/using-diffusers/consisid.md index 2b17575dc962..07c13c4c66b3 100644 --- a/docs/source/en/using-diffusers/consisid.md +++ b/docs/source/en/using-diffusers/consisid.md @@ -20,8 +20,8 @@ specific language governing permissions and limitations under the License. This guide will walk you through using ConsisID for use cases. ## Load Model Checkpoints -Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method. +Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method. ```python # !pip install consisid_eva_clip insightface facexlib @@ -42,6 +42,7 @@ pipe.to("cuda") ``` ## Identity-Preserving Text-to-Video + For identity-preserving text-to-video, pass a text prompt and an image contain clear face (e.g., preferably half-body or full-body). By default, ConsisID generates a 720x480 video for the best results. ```python diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 87fb52f16743..2db8b53db498 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -70,7 +70,6 @@ def text_encoder_attn_modules(text_encoder): "LoraLoaderMixin", "FluxLoraLoaderMixin", "CogVideoXLoraLoaderMixin", - "ConsisIDLoraLoaderMixin", "Mochi1LoraLoaderMixin", "HunyuanVideoLoraLoaderMixin", "SanaLoraLoaderMixin", @@ -102,7 +101,6 @@ def text_encoder_attn_modules(text_encoder): from .lora_pipeline import ( AmusedLoraLoaderMixin, CogVideoXLoraLoaderMixin, - ConsisIDLoraLoaderMixin, FluxLoraLoaderMixin, HunyuanVideoLoraLoaderMixin, LoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index d4047b33c77a..efefe5264daa 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2590,313 +2590,6 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components) -class ConsisIDLoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`ConsisIDTransformer3DModel`]. Specific to [`ConsisIDPipeline`]. - """ - - _lora_loadable_modules = ["transformer"] - transformer_name = TRANSFORMER_NAME - - @classmethod - @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. - - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. 
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - - return state_dict - - def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. 
All kwargs are forwarded to `self.lora_state_dict`. See - [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->ConsisIDTransformer3DModel - def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False - ): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - transformer (`ConsisIDTransformer3DModel`): - The Transformer model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - """ - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # Load the layers corresponding to transformer. 
- logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=None, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - @classmethod - # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. - transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `transformer`. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - """ - state_dict = {} - - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") - - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer - def fuse_lora( - self, - components: List[str] = ["transformer", "text_encoder"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. 
- - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names - ) - - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer - def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - unfuse_text_encoder (`bool`, defaults to `True`): - Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. - """ - super().unfuse_lora(components=components) - - class Mochi1LoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`]. diff --git a/src/diffusers/models/transformers/consisid_transformer_3d.py b/src/diffusers/models/transformers/consisid_transformer_3d.py index 8ca7eec780ea..86a6628b5161 100644 --- a/src/diffusers/models/transformers/consisid_transformer_3d.py +++ b/src/diffusers/models/transformers/consisid_transformer_3d.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn @@ -33,61 +33,10 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def reshape_tensor(x, heads): - """ - Reshapes the input tensor for multi-head attention. - - Args: - x (torch.Tensor): The input tensor with shape (batch_size, length, width). - heads (int): The number of attention heads. - - Returns: - torch.Tensor: The reshaped tensor, with shape (batch_size, heads, length, width). - """ - bs, length, width = x.shape - x = x.view(bs, length, heads, -1) - x = x.transpose(1, 2) - x = x.reshape(bs, heads, length, -1) - return x - - -def ConsisIDFeedForward(dim, mult=4): - """ - Creates a consistent ID feedforward block consisting of layer normalization, two linear layers, and a GELU - activation. - - Args: - dim (int): The input dimension of the tensor. - mult (int, optional): Multiplier for the inner dimension. Default is 4. - - Returns: - nn.Sequential: A sequence of layers comprising LayerNorm, Linear layers, and GELU. - """ - inner_dim = int(dim * mult) - return nn.Sequential( - nn.LayerNorm(dim), - nn.Linear(dim, inner_dim, bias=False), - nn.GELU(), - nn.Linear(inner_dim, dim, bias=False), - ) - - class PerceiverAttention(nn.Module): - """ - Implements the Perceiver attention mechanism with multi-head attention. - - This layer takes two inputs: 'x' (image features) and 'latents' (latent features), applying multi-head attention to - both and producing an output tensor with the same dimension as the input tensor 'x'. - - Args: - dim (int): The input dimension. 
- dim_head (int, optional): The dimension of each attention head. Default is 64. - heads (int, optional): The number of attention heads. Default is 8. - kv_dim (int, optional): The key-value dimension. If None, `dim` is used for both keys and values. - """ - - def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None): + def __init__(self, dim: int, dim_head: int = 64, heads: int = 8, kv_dim: Optional[int] = None): super().__init__() + self.scale = dim_head**-0.5 self.dim_head = dim_head self.heads = heads @@ -100,74 +49,49 @@ def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None): self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False) self.to_out = nn.Linear(inner_dim, dim, bias=False) - def forward(self, x, latents): - """ - Forward pass for Perceiver attention. - - Args: - x (torch.Tensor): Image features tensor with shape (batch_size, num_pixels, D). - latents (torch.Tensor): Latent features tensor with shape (batch_size, num_latents, D). - - Returns: - torch.Tensor: Output tensor after applying attention and transformation. - """ + def forward(self, image_embeds: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: # Apply normalization - x = self.norm1(x) + image_embeds = self.norm1(image_embeds) latents = self.norm2(latents) - b, seq_len, _ = latents.shape # Get batch size and sequence length + batch_size, seq_len, _ = latents.shape # Get batch size and sequence length # Compute query, key, and value matrices - q = self.to_q(latents) - kv_input = torch.cat((x, latents), dim=-2) - k, v = self.to_kv(kv_input).chunk(2, dim=-1) + query = self.to_q(latents) + kv_input = torch.cat((image_embeds, latents), dim=-2) + key, value = self.to_kv(kv_input).chunk(2, dim=-1) # Reshape the tensors for multi-head attention - q = reshape_tensor(q, self.heads) - k = reshape_tensor(k, self.heads) - v = reshape_tensor(v, self.heads) + query = query.reshape(query.size(0), -1, self.heads, self.dim_head).transpose(1, 2) + key = key.reshape(key.size(0), -1, self.heads, self.dim_head).transpose(1, 2) + value = value.reshape(value.size(0), -1, self.heads, self.dim_head).transpose(1, 2) # attention scale = 1 / math.sqrt(math.sqrt(self.dim_head)) - weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = (query * scale) @ (key * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - out = weight @ v + output = weight @ value # Reshape and return the final output - out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1) + output = output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1) - return self.to_out(out) + return self.to_out(output) class LocalFacialExtractor(nn.Module): def __init__( self, - id_dim=1280, - vit_dim=1024, - depth=10, - dim_head=64, - heads=16, - num_id_token=5, - num_queries=32, - output_dim=2048, - ff_mult=4, - num_scale=5, + id_dim: int = 1280, + vit_dim: int = 1024, + depth: int = 10, + dim_head: int = 64, + heads: int = 16, + num_id_token: int = 5, + num_queries: int = 32, + output_dim: int = 2048, + ff_mult: int = 4, + num_scale: int = 5, ): - """ - Initializes the LocalFacialExtractor class. - - Parameters: - - id_dim (int): The dimensionality of id features. - - vit_dim (int): The dimensionality of vit features. - - depth (int): Total number of PerceiverAttention and ConsisIDFeedForward layers. - - dim_head (int): Dimensionality of each attention head. - - heads (int): Number of attention heads. 
- - num_id_token (int): Number of tokens used for identity features. - - num_queries (int): Number of query tokens for the latent representation. - - output_dim (int): Output dimension after projection. - - ff_mult (int): Multiplier for the feed-forward network hidden dimension. - - num_scale (int): The number of different scales visual feature. - """ super().__init__() # Storing identity token and query information @@ -191,7 +115,12 @@ def __init__( nn.ModuleList( [ PerceiverAttention(dim=vit_dim, dim_head=dim_head, heads=heads), # Perceiver Attention layer - ConsisIDFeedForward(dim=vit_dim, mult=ff_mult), # ConsisIDFeedForward layer + nn.Sequential( + nn.LayerNorm(vit_dim), + nn.Linear(vit_dim, vit_dim * ff_mult, bias=False), + nn.GELU(), + nn.Linear(vit_dim * ff_mult, vit_dim, bias=False), + ), # ConsisIDFeedForward layer ] ) ) @@ -223,32 +152,21 @@ def __init__( nn.Linear(vit_dim, vit_dim * num_id_token), ) - def forward(self, x, y): - """ - Forward pass for LocalFacialExtractor. - - Parameters: - - x (Tensor): The input identity embedding tensor of shape (batch_size, id_dim). - - y (list of Tensor): A list of 5 visual feature tensors each of shape (batch_size, vit_dim). - - Returns: - - Tensor: The extracted latent features of shape (batch_size, num_queries, output_dim). - """ - + def forward(self, id_embeds: torch.Tensor, vit_hidden_states: List[torch.Tensor]) -> torch.Tensor: # Repeat latent queries for the batch size - latents = self.latents.repeat(x.size(0), 1, 1) + latents = self.latents.repeat(id_embeds.size(0), 1, 1) # Map the identity embedding to tokens - x = self.id_embedding_mapping(x) - x = x.reshape(-1, self.num_id_token, self.vit_dim) + id_embeds = self.id_embedding_mapping(id_embeds) + id_embeds = id_embeds.reshape(-1, self.num_id_token, self.vit_dim) # Concatenate identity tokens with the latent queries - latents = torch.cat((latents, x), dim=1) + latents = torch.cat((latents, id_embeds), dim=1) # Process each of the num_scale visual feature inputs for i in range(self.num_scale): - vit_feature = getattr(self, f"mapping_{i}")(y[i]) - ctx_feature = torch.cat((x, vit_feature), dim=1) + vit_feature = getattr(self, f"mapping_{i}")(vit_hidden_states[i]) + ctx_feature = torch.cat((id_embeds, vit_feature), dim=1) # Pass through the PerceiverAttention and ConsisIDFeedForward layers for attn, ff in self.layers[i * self.depth : (i + 1) * self.depth]: @@ -263,26 +181,9 @@ def forward(self, x, y): class PerceiverCrossAttention(nn.Module): - """ - - Args: - dim (int): Dimension of the input latent and output. Default is 3072. - dim_head (int): Dimension of each attention head. Default is 128. - heads (int): Number of attention heads. Default is 16. - kv_dim (int): Dimension of the key/value input, allowing flexible cross-attention. Default is 2048. - - Attributes: - scale (float): Scaling factor used in dot-product attention for numerical stability. - norm1 (nn.LayerNorm): Layer normalization applied to the input image features. - norm2 (nn.LayerNorm): Layer normalization applied to the latent features. - to_q (nn.Linear): Linear layer for projecting the latent features into queries. - to_kv (nn.Linear): Linear layer for projecting the input features into keys and values. - to_out (nn.Linear): Linear layer for outputting the final result after attention. 
- - """ - - def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048): + def __init__(self, dim: int = 3072, dim_head: int = 128, heads: int = 16, kv_dim: int = 2048): super().__init__() + self.scale = dim_head**-0.5 self.dim_head = dim_head self.heads = heads @@ -297,47 +198,32 @@ def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048): self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False) self.to_out = nn.Linear(inner_dim, dim, bias=False) - def forward(self, x, latents): - """ - - Args: - x (torch.Tensor): Input image features with shape (batch_size, n1, D), where: - - batch_size (b): Number of samples in the batch. - - n1: Sequence length (e.g., number of patches or tokens). - - D: Feature dimension. - - latents (torch.Tensor): Latent feature representations with shape (batch_size, n2, D), where: - - n2: Number of latent elements. - - Returns: - torch.Tensor: Attention-modulated features with shape (batch_size, n2, D). - - """ + def forward(self, image_embeds: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor: # Apply layer normalization to the input image and latent features - x = self.norm1(x) - latents = self.norm2(latents) + image_embeds = self.norm1(image_embeds) + hidden_states = self.norm2(hidden_states) - b, seq_len, _ = latents.shape + batch_size, seq_len, _ = hidden_states.shape # Compute queries, keys, and values - q = self.to_q(latents) - k, v = self.to_kv(x).chunk(2, dim=-1) + query = self.to_q(hidden_states) + key, value = self.to_kv(image_embeds).chunk(2, dim=-1) # Reshape tensors to split into attention heads - q = reshape_tensor(q, self.heads) - k = reshape_tensor(k, self.heads) - v = reshape_tensor(v, self.heads) + query = query.reshape(query.size(0), -1, self.heads, self.dim_head).transpose(1, 2) + key = key.reshape(key.size(0), -1, self.heads, self.dim_head).transpose(1, 2) + value = value.reshape(value.size(0), -1, self.heads, self.dim_head).transpose(1, 2) # Compute attention weights scale = 1 / math.sqrt(math.sqrt(self.dim_head)) - weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable scaling than post-division + weight = (query * scale) @ (key * scale).transpose(-2, -1) # More stable scaling than post-division weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) # Compute the output via weighted combination of values - out = weight @ v + out = weight @ value # Reshape and permute to prepare for final linear transformation - out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1) + out = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1) return self.to_out(out) @@ -680,8 +566,6 @@ def __init__( ) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) - self.gradient_checkpointing = False - self.is_train_face = is_train_face self.is_kps = is_kps @@ -709,12 +593,12 @@ def __init__( # face modules self._init_face_inputs() + self.gradient_checkpointing = False + def _set_gradient_checkpointing(self, module, value=False): self.gradient_checkpointing = value def _init_face_inputs(self): - device = self.device - weight_dtype = self.dtype self.local_facial_extractor = LocalFacialExtractor( id_dim=self.LFE_id_dim, vit_dim=self.LFE_vit_dim, @@ -727,7 +611,6 @@ def _init_face_inputs(self): ff_mult=self.LFE_ff_mult, num_scale=self.LFE_num_scale, ) - self.local_facial_extractor.to(device, dtype=weight_dtype) self.perceiver_cross_attention = nn.ModuleList( [ PerceiverCrossAttention( @@ -735,7 +618,7 @@ def _init_face_inputs(self): 
dim_head=self.cross_attn_dim_head, heads=self.cross_attn_num_heads, kv_dim=self.cross_attn_kv_dim, - ).to(device, dtype=weight_dtype) + ) for _ in range(self.num_cross_attn) ] ) @@ -828,8 +711,8 @@ def forward( ) # fuse clip and insightface + valid_face_emb = None if self.is_train_face: - assert id_cond is not None and id_vit_hidden is not None id_cond = id_cond.to(device=hidden_states.device, dtype=hidden_states.dtype) id_vit_hidden = [ tensor.to(device=hidden_states.device, dtype=hidden_states.dtype) for tensor in id_vit_hidden diff --git a/src/diffusers/pipelines/consisid/pipeline_consisid.py b/src/diffusers/pipelines/consisid/pipeline_consisid.py index 4e1b5cf08911..0d4891cf17d7 100644 --- a/src/diffusers/pipelines/consisid/pipeline_consisid.py +++ b/src/diffusers/pipelines/consisid/pipeline_consisid.py @@ -24,15 +24,12 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput -from ...loaders import ConsisIDLoraLoaderMixin +from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, ConsisIDTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDPMScheduler -from ...utils import ( - logging, - replace_example_docstring, -) +from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from .pipeline_output import ConsisIDPipelineOutput @@ -242,7 +239,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class ConsisIDPipeline(DiffusionPipeline, ConsisIDLoraLoaderMixin): +class ConsisIDPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): r""" Pipeline for image-to-video generation using ConsisID.
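
Usage note for this change: with `ConsisIDLoraLoaderMixin` removed, `ConsisIDPipeline` now inherits its LoRA handling from `CogVideoXLoraLoaderMixin`, so LoRA checkpoints go through the shared loader API. Below is a minimal sketch of what that looks like after the patch; the checkpoint id, LoRA path, weight name, and adapter name are placeholders for illustration, not values taken from this patch.

```python
import torch

from diffusers import ConsisIDPipeline

# Placeholder model id and LoRA paths -- substitute real checkpoints.
pipe = ConsisIDPipeline.from_pretrained("<consisid-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# LoRA loading is now provided by CogVideoXLoraLoaderMixin, whose only
# LoRA-loadable module is the transformer component of the pipeline.
pipe.load_lora_weights("<lora-repo-or-path>", weight_name="lora.safetensors", adapter_name="demo")
pipe.set_adapters(["demo"], [0.8])  # optionally scale the adapter
```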
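
The transformer refactor also drops the module-level `reshape_tensor` helper and inlines the `(batch, seq, heads * dim_head) -> (batch, heads, seq, dim_head)` reshape directly in `PerceiverAttention` and `PerceiverCrossAttention`. A small self-contained check that the inlined form matches the removed helper (tensor sizes are illustrative only):

```python
import torch

batch_size, seq_len, heads, dim_head = 2, 7, 8, 64
x = torch.randn(batch_size, seq_len, heads * dim_head)

# Removed helper: view -> transpose -> reshape
old = x.view(batch_size, seq_len, heads, -1).transpose(1, 2).reshape(batch_size, heads, seq_len, -1)

# Inlined form now used in the patched attention modules
new = x.reshape(x.size(0), -1, heads, dim_head).transpose(1, 2)

assert old.shape == new.shape == (batch_size, heads, seq_len, dim_head)
assert torch.equal(old, new)
```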