diff --git a/helpers/caching/text_embeds.py b/helpers/caching/text_embeds.py
index 84994320..f35e9350 100644
--- a/helpers/caching/text_embeds.py
+++ b/helpers/caching/text_embeds.py
@@ -599,14 +599,18 @@ def compute_t5_prompt(self, prompt: str):
 
         return result, attn_mask
 
-    def compute_gemma_prompt(self, prompt: str):
+    def compute_gemma_prompt(self, prompt: str, is_negative_prompt: bool):
         prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
             prompt=prompt,
             do_classifier_free_guidance=False,
             device=self.accelerator.device,
             clean_caption=False,
             max_sequence_length=300,
-            complex_human_instruction=StateTracker.get_args().sana_complex_human_instruction,
+            complex_human_instruction=(
+                StateTracker.get_args().sana_complex_human_instruction
+                if not is_negative_prompt
+                else None
+            ),
         )
 
         return prompt_embeds, prompt_attention_mask
@@ -617,6 +621,7 @@ def compute_embeddings_for_prompts(
         return_concat: bool = True,
         is_validation: bool = False,
         load_from_cache: bool = True,
+        is_negative_prompt: bool = False,
     ):
         logger.debug("Initialising text embed calculator...")
         if not self.batch_write_thread.is_alive():
@@ -695,6 +700,7 @@ def compute_embeddings_for_prompts(
                 raw_prompts,
                 return_concat=return_concat,
                 load_from_cache=load_from_cache,
+                is_negative_prompt=is_negative_prompt,
             )
         else:
             raise ValueError(
@@ -1074,6 +1080,7 @@ def compute_embeddings_for_sana_prompts(
         prompts: list = None,
         return_concat: bool = True,
         load_from_cache: bool = True,
+        is_negative_prompt: bool = False,
     ):
         logger.debug(
             f"compute_embeddings_for_sana_prompts arguments: prompts={prompts}, return_concat={return_concat}, load_from_cache={load_from_cache}"
         )
@@ -1172,7 +1179,7 @@ def compute_embeddings_for_sana_prompts(
                     time.sleep(5)
                 # TODO: Batch this
                 prompt_embeds, attention_mask = self.compute_gemma_prompt(
-                    prompt=prompt,
+                    prompt=prompt, is_negative_prompt=is_negative_prompt
                 )
                 if "deepfloyd" not in StateTracker.get_args().model_type:
                     # we have to store the attn mask with the embed for pixart.
diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py
index b3ed0631..89ff22f0 100644
--- a/helpers/configuration/cmd_args.py
+++ b/helpers/configuration/cmd_args.py
@@ -2561,6 +2561,7 @@ def parse_cmdline_args(input_args=None, exit_on_error: bool = False):
     if (
         args.sana_complex_human_instruction is not None
         and type(args.sana_complex_human_instruction) is str
+        and args.sana_complex_human_instruction not in ["", "None"]
     ):
         try:
             import json
@@ -2569,8 +2570,12 @@ def parse_cmdline_args(input_args=None, exit_on_error: bool = False):
                 args.sana_complex_human_instruction
             )
         except Exception as e:
-            logger.error(f"Could not load complex human instruction: {e}")
+            logger.error(
+                f"Could not load complex human instruction ({args.sana_complex_human_instruction}): {e}"
+            )
             raise
+    elif args.sana_complex_human_instruction == "None":
+        args.sana_complex_human_instruction = None
 
     if args.enable_xformers_memory_efficient_attention:
         if args.attention_mechanism != "xformers":
diff --git a/helpers/training/validation.py b/helpers/training/validation.py
index 8055e452..de8cf7cf 100644
--- a/helpers/training/validation.py
+++ b/helpers/training/validation.py
@@ -258,11 +258,10 @@ def prepare_validation_prompt_list(args, embed_cache):
         or model_type == "sana"
     ):
         # we use the legacy encoder but we return no pooled embeds.
-        validation_negative_prompt_embeds = (
-            embed_cache.compute_embeddings_for_prompts(
-                [StateTracker.get_args().validation_negative_prompt],
-                load_from_cache=False,
-            )
+        validation_negative_prompt_embeds = embed_cache.compute_embeddings_for_prompts(
+            [StateTracker.get_args().validation_negative_prompt],
+            load_from_cache=False,
+            is_negative_prompt=True,  # sana needs this to disable Complex Human Instruction on negative embed generation
         )
 
         return (
@@ -1388,6 +1387,11 @@ def validate_prompt(
         if StateTracker.get_model_family() == "flux":
             if "negative_prompt" in pipeline_kwargs:
                 del pipeline_kwargs["negative_prompt"]
+        if self.args.model_family == "sana":
+            pipeline_kwargs["complex_human_instruction"] = (
+                self.args.sana_complex_human_instruction
+            )
+
         if (
             StateTracker.get_model_family() == "pixart_sigma"
             or StateTracker.get_model_family() == "smoldit"
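
For reviewers, a minimal sketch of the behaviour these hunks thread through: Sana's "complex human instruction" (CHI) is a quality-boosting wrapper that should only ever be applied to positive prompts, so an is_negative_prompt flag is passed from the validation negative-embed path down to compute_gemma_prompt, which then hands complex_human_instruction=None to the encoder. The helper name apply_chi and the SAMPLE_CHI text below are invented for illustration; the real path runs through pipeline.encode_prompt(), and only the gating logic mirrors the diff.

    # Hypothetical stand-alone sketch; apply_chi() and SAMPLE_CHI are not
    # project APIs. chi is a list of instruction fragments, matching the
    # json.loads() handling of --sana_complex_human_instruction above.
    SAMPLE_CHI = [
        "Given a user prompt, generate a detailed, high-quality image. ",
        "User Prompt: ",
    ]

    def apply_chi(prompt: str, chi, is_negative_prompt: bool) -> str:
        # Negative prompts skip the instruction entirely: wrapping e.g.
        # "blurry, low quality" in a "generate a high-quality image"
        # instruction would steer the negative embedding toward quality,
        # the opposite of what classifier-free guidance needs.
        if is_negative_prompt or not chi:
            return prompt
        return "".join(chi) + prompt

    print(apply_chi("a photo of a cat", SAMPLE_CHI, is_negative_prompt=False))
    print(apply_chi("blurry, low quality", SAMPLE_CHI, is_negative_prompt=True))

The same reasoning drives the cmd_args.py hunks: an empty string or the literal string "None" now disables CHI outright (normalised to None) instead of being fed to json.loads() and crashing.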