Add Hunyuan HF (support LoRA finetuning); unify Hunyuan HF inference with quantization #135

Open

wants to merge 28 commits into main

Changes from all commits (28 commits)
a133b58
fix lora cp saving issue
BrianChen1129 Dec 19, 2024
3008adf
Merge branch 'main' of github.com:jzhang38/FastVideo-OSP
BrianChen1129 Dec 19, 2024
a3f8fc2
fix lora save issue
BrianChen1129 Dec 19, 2024
580575c
fix lora save issue
BrianChen1129 Dec 19, 2024
4c16991
Revert "fix lora save issue"
BrianChen1129 Dec 19, 2024
677efd9
fix lora save issue
BrianChen1129 Dec 19, 2024
48fd2d2
Merge branch 'main' of github.com:jzhang38/FastVideo-OSP into yq-lora…
BrianChen1129 Dec 20, 2024
c31f507
debug hunyuan hf sp
BrianChen1129 Dec 23, 2024
ac58b2f
test hunyuan hf
BrianChen1129 Dec 24, 2024
77b690c
add huanyuan hf inference and train
BrianChen1129 Dec 24, 2024
04c1610
support hunyuan hf lora
BrianChen1129 Dec 27, 2024
bfe0448
syn with main
BrianChen1129 Dec 27, 2024
594b65a
syn with main
BrianChen1129 Dec 27, 2024
36560eb
unified hunyuan hf
BrianChen1129 Dec 29, 2024
d25c235
unified hunyuan hf
BrianChen1129 Dec 29, 2024
557dbca
unified hunyuan hf
BrianChen1129 Dec 29, 2024
0c058c3
add lora
BrianChen1129 Jan 7, 2025
cbc52b0
syn with main
BrianChen1129 Jan 7, 2025
f35ea70
unify hunyuan hf inference
BrianChen1129 Jan 7, 2025
cad6e6d
unify hunyuan hf inference
BrianChen1129 Jan 7, 2025
1832042
syn with main
BrianChen1129 Jan 7, 2025
5d21b16
syn
BrianChen1129 Jan 7, 2025
1d7d637
syn
BrianChen1129 Jan 7, 2025
c1cf441
syn
BrianChen1129 Jan 7, 2025
22d499b
syn
BrianChen1129 Jan 7, 2025
e7ea0d7
syn
BrianChen1129 Jan 7, 2025
afe24e2
syn
BrianChen1129 Jan 7, 2025
41f8ac9
syn
BrianChen1129 Jan 8, 2025
952 changes: 952 additions & 0 deletions fastvideo/models/hunyuan_hf/modeling_hunyuan.py

Large diffs are not rendered by default.

756 changes: 756 additions & 0 deletions fastvideo/models/hunyuan_hf/pipeline_hunyuan.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions fastvideo/models/mochi_hf/mochi_latents_utils.py
@@ -37,6 +37,8 @@ def normalize_dit_input(model_type, latents):
         latents_std = mochi_latents_std.to(latents.device, latents.dtype)
         latents = (latents - latents_mean) / latents_std
         return latents
+    elif model_type == "hunyuan_hf":
+        return latents * 0.476986
     elif model_type == "hunyuan":
         return latents * 0.476986
     else:
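For context, the new "hunyuan_hf" branch behaves exactly like the existing "hunyuan" one: normalization is a single multiplication by the HunyuanVideo VAE scaling factor 0.476986, whereas the Mochi branch above standardizes latents with per-channel mean and std. A minimal sketch of the added behavior (function body simplified, tensor shape purely illustrative):

import torch

def normalize_hunyuan_latents(latents: torch.Tensor) -> torch.Tensor:
    # Same math as the new branch: scale raw VAE latents by the scaling factor.
    return latents * 0.476986

latents = torch.randn(1, 16, 8, 32, 32)  # (B, C, T, H, W); shape illustrative
normalized = normalize_hunyuan_latents(latents)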
@@ -1,35 +1,131 @@
import argparse
import json
import os
import time

import torch
import torch.distributed as dist
from diffusers import BitsAndBytesConfig
from diffusers.utils import export_to_video

from fastvideo.models.hunyuan_hf.pipeline_hunyuan import HunyuanVideoPipeline
from fastvideo.models.hunyuan_hf.modeling_hunyuan import HunyuanVideoTransformer3DModel
from fastvideo.utils.parallel_states import (
    initialize_sequence_parallel_state,
    nccl_info,
)


def initialize_distributed():
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    local_rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    print("world_size", world_size)
    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend="nccl", init_method="env://", world_size=world_size, rank=local_rank
    )
    initialize_sequence_parallel_state(world_size)


def inference(args):
    initialize_distributed()
    print(nccl_info.sp_size)
    device = torch.cuda.current_device()
    # Peiyuan: GPU seed will cause A100 and H100 to produce different results .....
    weight_dtype = torch.bfloat16

    if args.transformer_path is not None:
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            args.transformer_path
        )
    else:
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            args.model_path, subfolder="transformer/", torch_dtype=weight_dtype
        )

    pipe = HunyuanVideoPipeline.from_pretrained(
        args.model_path, transformer=transformer, torch_dtype=weight_dtype
    )

    pipe.enable_vae_tiling()

    if args.lora_checkpoint_dir is not None:
        print(f"Loading LoRA weights from {args.lora_checkpoint_dir}")
        config_path = os.path.join(args.lora_checkpoint_dir, "lora_config.json")
        with open(config_path, "r") as f:
            lora_config_dict = json.load(f)
        rank = lora_config_dict["lora_params"]["lora_rank"]
        lora_alpha = lora_config_dict["lora_params"]["lora_alpha"]
        lora_scaling = lora_alpha / rank
        pipe.load_lora_weights(args.lora_checkpoint_dir, adapter_name="default")
        pipe.set_adapters(["default"], [lora_scaling])
        print(f"Successfully Loaded LoRA weights from {args.lora_checkpoint_dir}")
    if args.cpu_offload:
        pipe.enable_model_cpu_offload(device)
    else:
        pipe.to(device)

    # Generate videos from the input prompt

    if args.prompt_embed_path is not None:
        prompt_embeds = (
            torch.load(args.prompt_embed_path, map_location="cpu", weights_only=True)
            .to(device)
            .unsqueeze(0)
        )
        encoder_attention_mask = (
            torch.load(
                args.encoder_attention_mask_path, map_location="cpu", weights_only=True
            )
            .to(device)
            .unsqueeze(0)
        )
        prompts = None
    elif args.prompt_path is not None:
        prompts = [line.strip() for line in open(args.prompt_path, "r")]
        prompt_embeds = None
        encoder_attention_mask = None
    else:
        # argparse below defines --prompt (a single string), so wrap it in a
        # list to match the loop over prompts.
        prompts = [args.prompt]
        prompt_embeds = None
        encoder_attention_mask = None

    if prompts is not None:
        with torch.autocast("cuda", dtype=torch.bfloat16):
            for prompt in prompts:
                generator = torch.Generator("cpu").manual_seed(args.seed)
                video = pipe(
                    prompt=[prompt],
                    height=args.height,
                    width=args.width,
                    num_frames=args.num_frames,
                    num_inference_steps=args.num_inference_steps,
                    generator=generator,
                ).frames
                if nccl_info.global_rank <= 0:
                    os.makedirs(args.output_path, exist_ok=True)
                    suffix = prompt.split(".")[0]
                    export_to_video(
                        video[0],
                        os.path.join(args.output_path, f"{suffix}.mp4"),
                        fps=24,
                    )
    else:
        with torch.autocast("cuda", dtype=torch.bfloat16):
            generator = torch.Generator("cpu").manual_seed(args.seed)
            videos = pipe(
                prompt_embeds=prompt_embeds,
                prompt_attention_mask=encoder_attention_mask,
                height=args.height,
                width=args.width,
                num_frames=args.num_frames,
                num_inference_steps=args.num_inference_steps,
                generator=generator,
            ).frames

        if nccl_info.global_rank <= 0:
            export_to_video(videos[0], args.output_path + ".mp4", fps=24)

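For reference, inference() above expects the LoRA directory to hold both the PEFT weights and a small lora_config.json describing rank and alpha; the adapter scale it passes to set_adapters() is lora_alpha / lora_rank. A minimal sketch of that config file (the directory name and values are illustrative, not from the PR):

import json
import os

ckpt_dir = "outputs/hunyuan_lora"  # hypothetical checkpoint directory
os.makedirs(ckpt_dir, exist_ok=True)
with open(os.path.join(ckpt_dir, "lora_config.json"), "w") as f:
    json.dump({"lora_params": {"lora_rank": 64, "lora_alpha": 128}}, f)
# inference() would then apply the scale lora_alpha / lora_rank = 128 / 64 = 2.0.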
def inference_quantization(args):
    torch.manual_seed(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    prompt_template = {
@@ -42,10 +138,8 @@ def main(args):
"5. Camera angles, movements, and transitions used in the video."
"6. Thematic and aesthetic concepts associated with the scene, i.e. realistic, futuristic, fairy tale, etc<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"),
"crop_start":
95,
"crop_start":95,
}
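    # crop_start = 95: the number of leading template tokens the pipeline later
    # slices off the LLM text-encoder hidden states (roughly
    # hidden_states[:, crop_start:]), so the instruction prompt above does not
    # leak into the video conditioning.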

    model_id = args.model_path

    if args.quantization == "nf4":
@@ -119,36 +213,40 @@ def main(args):
          round(torch.cuda.max_memory_allocated(device="cuda") / 1024**3, 3),
          "GiB")


if __name__ == "__main__":
parser = argparse.ArgumentParser()

# Basic parameters
parser.add_argument("--prompt", type=str, help="prompt file for inference")
parser.add_argument("--prompt_embed_path", type=str, default=None)
parser.add_argument("--prompt_path", type=str, default=None)
parser.add_argument("--num_frames", type=int, default=16)
parser.add_argument("--height", type=int, default=256)
parser.add_argument("--width", type=int, default=256)
parser.add_argument("--num_inference_steps", type=int, default=50)
parser.add_argument("--model_path", type=str, default="data/hunyuan")
parser.add_argument("--transformer_path", type=str, default=None)
parser.add_argument("--output_path", type=str, default="./outputs/video")
parser.add_argument("--fps", type=int, default=24)
parser.add_argument("--quantization", type=str, default=None)
parser.add_argument("--cpu_offload", action="store_true")
parser.add_argument(
"--lora_checkpoint_dir",
type=str,
default=None,
help="Path to the directory containing LoRA checkpoints",
)
# Additional parameters
parser.add_argument(
"--denoise-type",
type=str,
default="flow",
help="Denoise type for noised inputs.",
)
parser.add_argument("--seed",
type=int,
default=None,
help="Seed for evaluation.")
parser.add_argument("--neg_prompt",
type=str,
default=None,
help="Negative prompt for sampling.")
parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
parser.add_argument(
"--neg_prompt", type=str, default=None, help="Negative prompt for sampling."
)
parser.add_argument(
"--guidance_scale",
type=float,
Expand All @@ -161,14 +259,12 @@ def main(args):
default=6.0,
help="Embedded classifier free guidance scale.",
)
parser.add_argument("--flow_shift",
type=int,
default=7,
help="Flow shift parameter.")
parser.add_argument("--batch_size",
type=int,
default=1,
help="Batch size for inference.")
parser.add_argument(
"--flow_shift", type=int, default=7, help="Flow shift parameter."
)
parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size for inference."
)
parser.add_argument(
"--num_videos",
type=int,
Expand All @@ -179,8 +275,7 @@ def main(args):
"--load-key",
type=str,
default="module",
help=
"Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
)
parser.add_argument(
"--use-cpu-offload",
Expand All @@ -190,20 +285,17 @@ def main(args):
parser.add_argument(
"--dit-weight",
type=str,
default=
"data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
default="data/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
)
parser.add_argument(
"--reproduce",
action="store_true",
help=
"Enable reproducibility by setting random seeds and deterministic algorithms.",
help="Enable reproducibility by setting random seeds and deterministic algorithms.",
)
parser.add_argument(
"--disable-autocast",
action="store_true",
help=
"Disable autocast for denoising loop and vae decoding in pipeline sampling.",
help="Disable autocast for denoising loop and vae decoding in pipeline sampling.",
)

# Flow Matching
Expand All @@ -212,15 +304,13 @@ def main(args):
action="store_true",
help="If reverse, learning/sampling from t=1 -> t=0.",
)
parser.add_argument("--flow-solver",
type=str,
default="euler",
help="Solver for flow matching.")
parser.add_argument(
"--flow-solver", type=str, default="euler", help="Solver for flow matching."
)
parser.add_argument(
"--use-linear-quadratic-schedule",
action="store_true",
help=
"Use linear quadratic schedule for flow matching. Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
help="Use linear quadratic schedule for flow matching. Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
)
parser.add_argument(
"--linear-schedule-end",
Expand All @@ -232,20 +322,17 @@ def main(args):
# Model parameters
parser.add_argument("--model", type=str, default="HYVideo-T/2-cfgdistill")
parser.add_argument("--latent-channels", type=int, default=16)
parser.add_argument("--precision",
type=str,
default="bf16",
choices=["fp32", "fp16", "bf16", "fp8"])
parser.add_argument("--rope-theta",
type=int,
default=256,
help="Theta used in RoPE.")
parser.add_argument(
"--precision", type=str, default="bf16", choices=["fp32", "fp16", "bf16", "fp8"]
)
parser.add_argument(
"--rope-theta", type=int, default=256, help="Theta used in RoPE."
)

parser.add_argument("--vae", type=str, default="884-16c-hy")
parser.add_argument("--vae-precision",
type=str,
default="fp16",
choices=["fp32", "fp16", "bf16"])
parser.add_argument(
"--vae-precision", type=str, default="fp16", choices=["fp32", "fp16", "bf16"]
)
parser.add_argument("--vae-tiling", action="store_true", default=True)

parser.add_argument("--text-encoder", type=str, default="llm")
Expand All @@ -258,12 +345,10 @@ def main(args):
parser.add_argument("--text-states-dim", type=int, default=4096)
parser.add_argument("--text-len", type=int, default=256)
parser.add_argument("--tokenizer", type=str, default="llm")
parser.add_argument("--prompt-template",
type=str,
default="dit-llm-encode")
parser.add_argument("--prompt-template-video",
type=str,
default="dit-llm-encode-video")
parser.add_argument("--prompt-template", type=str, default="dit-llm-encode")
parser.add_argument(
"--prompt-template-video", type=str, default="dit-llm-encode-video"
)
parser.add_argument("--hidden-state-skip-layer", type=int, default=2)
parser.add_argument("--apply-final-norm", action="store_true")

Expand All @@ -279,4 +364,7 @@ def main(args):
parser.add_argument("--text-len-2", type=int, default=77)

args = parser.parse_args()
main(args)
if args.quantization:
inference_quantization(args)
else:
inference(args)
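The rewrite also drops the hand-rolled imageio/PyAV exporter in favor of diffusers.utils.export_to_video, so saving frames reduces to a single call. A standalone sketch with synthetic frames (purely illustrative, not from the PR):

import numpy as np
from diffusers.utils import export_to_video

# 16 synthetic RGB frames with float values in [0, 1]; export_to_video
# rescales them to uint8 and encodes an MP4 at the requested frame rate.
frames = [np.random.rand(256, 256, 3).astype(np.float32) for _ in range(16)]
export_to_video(frames, "sample.mp4", fps=24)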