Skip to content

Commit

Permalink
sana: no CHI for negative prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
bghira committed Dec 17, 2024
1 parent b3cb108 commit d3500aa
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 9 deletions.
13 changes: 10 additions & 3 deletions helpers/caching/text_embeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,14 +599,18 @@ def compute_t5_prompt(self, prompt: str):

return result, attn_mask

def compute_gemma_prompt(self, prompt: str, is_negative_prompt: bool = False):
    """Encode ``prompt`` with the pipeline's Gemma text encoder (Sana).

    Args:
        prompt: The text to encode.
        is_negative_prompt: When True, skip the Sana "complex human
            instruction" (CHI) prefix. CHI is an instruction wrapper meant
            for the positive prompt only; wrapping the negative prompt in
            it would corrupt the negative embedding. Defaults to False so
            existing callers are unaffected.

    Returns:
        Tuple of ``(prompt_embeds, prompt_attention_mask)`` as produced by
        ``self.pipeline.encode_prompt``.
    """
    prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
        prompt=prompt,
        do_classifier_free_guidance=False,
        device=self.accelerator.device,
        clean_caption=False,
        max_sequence_length=300,
        # Only the positive prompt receives the configured CHI wrapper;
        # negative prompts pass None so the encoder sees the raw text.
        complex_human_instruction=(
            StateTracker.get_args().sana_complex_human_instruction
            if not is_negative_prompt
            else None
        ),
    )

    return prompt_embeds, prompt_attention_mask
Expand All @@ -617,6 +621,7 @@ def compute_embeddings_for_prompts(
return_concat: bool = True,
is_validation: bool = False,
load_from_cache: bool = True,
is_negative_prompt: bool = False,
):
logger.debug("Initialising text embed calculator...")
if not self.batch_write_thread.is_alive():
Expand Down Expand Up @@ -695,6 +700,7 @@ def compute_embeddings_for_prompts(
raw_prompts,
return_concat=return_concat,
load_from_cache=load_from_cache,
is_negative_prompt=is_negative_prompt,
)
else:
raise ValueError(
Expand Down Expand Up @@ -1074,6 +1080,7 @@ def compute_embeddings_for_sana_prompts(
prompts: list = None,
return_concat: bool = True,
load_from_cache: bool = True,
is_negative_prompt: bool = False,
):
logger.debug(
f"compute_embeddings_for_sana_prompts arguments: prompts={prompts}, return_concat={return_concat}, load_from_cache={load_from_cache}"
Expand Down Expand Up @@ -1172,7 +1179,7 @@ def compute_embeddings_for_sana_prompts(
time.sleep(5)
# TODO: Batch this
prompt_embeds, attention_mask = self.compute_gemma_prompt(
prompt=prompt,
prompt=prompt, is_negative_prompt=is_negative_prompt
)
if "deepfloyd" not in StateTracker.get_args().model_type:
# we have to store the attn mask with the embed for pixart.
Expand Down
7 changes: 6 additions & 1 deletion helpers/configuration/cmd_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2561,6 +2561,7 @@ def parse_cmdline_args(input_args=None, exit_on_error: bool = False):
if (
args.sana_complex_human_instruction is not None
and type(args.sana_complex_human_instruction) is str
and args.sana_complex_human_instruction not in ["", "None"]
):
try:
import json
Expand All @@ -2569,8 +2570,12 @@ def parse_cmdline_args(input_args=None, exit_on_error: bool = False):
args.sana_complex_human_instruction
)
except Exception as e:
logger.error(f"Could not load complex human instruction: {e}")
logger.error(
f"Could not load complex human instruction ({args.sana_complex_human_instruction}): {e}"
)
raise
elif args.sana_complex_human_instruction == "None":
args.sana_complex_human_instruction = None

if args.enable_xformers_memory_efficient_attention:
if args.attention_mechanism != "xformers":
Expand Down
14 changes: 9 additions & 5 deletions helpers/training/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,10 @@ def prepare_validation_prompt_list(args, embed_cache):
or model_type == "sana"
):
# we use the legacy encoder but we return no pooled embeds.
validation_negative_prompt_embeds = (
embed_cache.compute_embeddings_for_prompts(
[StateTracker.get_args().validation_negative_prompt],
load_from_cache=False,
)
validation_negative_prompt_embeds = embed_cache.compute_embeddings_for_prompts(
[StateTracker.get_args().validation_negative_prompt],
load_from_cache=False,
is_negative_prompt=True, # sana needs this to disable Complex Human Instruction on negative embed generation
)

return (
Expand Down Expand Up @@ -1388,6 +1387,11 @@ def validate_prompt(
if StateTracker.get_model_family() == "flux":
if "negative_prompt" in pipeline_kwargs:
del pipeline_kwargs["negative_prompt"]
if self.args.model_family == "sana":
pipeline_kwargs["complex_human_instruction"] = (
self.args.sana_complex_human_instruction
)

if (
StateTracker.get_model_family() == "pixart_sigma"
or StateTracker.get_model_family() == "smoldit"
Expand Down

0 comments on commit d3500aa

Please sign in to comment.