From 5142a4243566918c57bc28e15b711e02da13eec5 Mon Sep 17 00:00:00 2001 From: Akhil Goel Date: Wed, 11 Sep 2024 18:41:20 -0700 Subject: [PATCH] Fix NaN bug in normalization for fp16 --- tripy/examples/diffusion/clip_model.py | 17 +++---- tripy/examples/diffusion/example.py | 58 +++++++++++++---------- tripy/examples/diffusion/helper.py | 2 +- tripy/examples/diffusion/model.py | 6 +-- tripy/examples/diffusion/unet_model.py | 38 ++++++++------- tripy/examples/diffusion/vae_model.py | 25 +++++----- tripy/examples/diffusion/weight_loader.py | 15 ++---- 7 files changed, 84 insertions(+), 77 deletions(-) diff --git a/tripy/examples/diffusion/clip_model.py b/tripy/examples/diffusion/clip_model.py index 98a3d2008..1c41d5f9c 100644 --- a/tripy/examples/diffusion/clip_model.py +++ b/tripy/examples/diffusion/clip_model.py @@ -29,7 +29,7 @@ class CLIPConfig: num_heads: int = 12 max_seq_len: int = 77 num_hidden_layers: int = 12 - dtype: tp.dtype = tp.float16 + dtype: tp.dtype = tp.float32 class CLIPMLP(tp.Module): def __init__(self, config: CLIPConfig): @@ -52,6 +52,7 @@ def __init__(self, config: CLIPConfig): self.v_proj = tp.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype) self.q_proj = tp.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype) self.out_proj = tp.Linear(self.embed_dim, self.embed_dim, dtype=config.dtype) + self.dtype = config.dtype def __call__(self, hidden_states, causal_attention_mask): bsz, tgt_len, embed_dim = hidden_states.shape[0], hidden_states.shape[1], hidden_states.shape[2] @@ -65,7 +66,7 @@ def __call__(self, hidden_states, causal_attention_mask): for x in (q, k, v) ] attn_output = scaled_dot_product_attention( - q, k, v, embedding_dim=self.head_dim, attn_mask=causal_attention_mask + q, k, v, embedding_dim=self.head_dim, attn_mask=causal_attention_mask, dtype=self.dtype, ) out = self.out_proj(tp.reshape(tp.transpose(attn_output, 1, 2), (bsz, tgt_len, embed_dim))) return out @@ -74,18 +75,18 @@ def __call__(self, hidden_states, causal_attention_mask): class CLIPEncoderLayer(tp.Module): def __init__(self, config: CLIPConfig): self.self_attn = CLIPAttention(config) - self.layer_norm1 = tp.LayerNorm(config.embedding_size, dtype=config.dtype) + self.layer_norm1 = tp.LayerNorm(config.embedding_size, dtype=tp.float32) self.mlp = CLIPMLP(config) - self.layer_norm2 = tp.LayerNorm(config.embedding_size, dtype=config.dtype) + self.layer_norm2 = tp.LayerNorm(config.embedding_size, dtype=tp.float32) def __call__(self, hidden_states, causal_attention_mask): residual = hidden_states - hidden_states = self.layer_norm1(hidden_states) + hidden_states = tp.cast(self.layer_norm1(tp.cast(hidden_states, self.layer_norm1.dtype)), hidden_states.dtype) hidden_states = self.self_attn(hidden_states, causal_attention_mask) hidden_states = residual + hidden_states residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) + hidden_states = tp.cast(self.layer_norm2(tp.cast(hidden_states, self.layer_norm2.dtype)), hidden_states.dtype) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states @@ -115,10 +116,10 @@ class CLIPTextTransformer(tp.Module): def __init__(self, config: CLIPConfig): self.embeddings = CLIPTextEmbeddings(config) self.encoder = CLIPEncoder(config) - self.final_layer_norm = tp.LayerNorm(config.embedding_size, dtype=config.dtype) + self.final_layer_norm = tp.LayerNorm(config.embedding_size, dtype=tp.float32) self.max_seq_len = config.max_seq_len def __call__(self, input_ids): x = self.embeddings(input_ids, tp.reshape(tp.iota((input_ids.shape[1],), dtype=tp.int32), (1, -1))) x = self.encoder(x, tp.triu(tp.full((1, 1, self.max_seq_len, self.max_seq_len), float("-inf")), 1)) - return self.final_layer_norm(x) \ No newline at end of file + return tp.cast(self.final_layer_norm(tp.cast(x, self.final_layer_norm.dtype)), x.dtype) \ No newline at end of file diff --git a/tripy/examples/diffusion/example.py b/tripy/examples/diffusion/example.py index de7a3ca61..8f1bd82d5 100644 --- a/tripy/examples/diffusion/example.py +++ b/tripy/examples/diffusion/example.py @@ -52,7 +52,7 @@ def compile_clip(model, dtype=tp.int32, verbose=False): return compile_model(model, inputs, verbose=verbose) -def compile_unet(model, dtype=tp.float16, verbose=False): +def compile_unet(model, dtype, verbose=False): unconditional_context_shape = (1, 77, 768) conditional_context_shape = (1, 77, 768) latent_shape = (1, 4, 64, 64) @@ -68,16 +68,16 @@ def compile_unet(model, dtype=tp.float16, verbose=False): return compile_model(model, inputs, verbose=verbose) -def compile_vae(model, dtype=tp.float16, verbose=False): +def compile_vae(model, dtype, verbose=False): inputs = (tp.InputInfo((1, 4, 64, 64), dtype=dtype),) return compile_model(model, inputs, verbose=verbose) -def run_diffusion_loop(model, unconditional_context, context, latent, steps, guidance): +def run_diffusion_loop(model, unconditional_context, context, latent, steps, guidance, dtype): timesteps = list(range(1, 1000, 1000 // steps)) - print(f"[I] Running diffusion for {timesteps} timesteps...") - alphas = get_alphas_cumprod()[tp.Tensor(timesteps)] - alphas_prev = tp.concatenate([tp.Tensor([1.0]), alphas[:-1]], dim=0) + print(f"[I] Running diffusion for {steps} timesteps...") + alphas = get_alphas_cumprod(dtype=dtype)[tp.Tensor(timesteps)] + alphas_prev = tp.concatenate([tp.Tensor([1.0], dtype=dtype), alphas[:-1]], dim=0) for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])): t.set_description("idx: %1d, timestep: %3d" % (index, timestep)) @@ -86,10 +86,10 @@ def run_diffusion_loop(model, unconditional_context, context, latent, steps, gui unconditional_context, context, latent, - tp.cast(tp.Tensor([timestep]), tp.float32), + tp.Tensor([timestep], dtype=dtype), alphas[tid], alphas_prev[tid], - tp.Tensor([guidance]), + tp.Tensor([guidance], dtype=dtype), ) return latent @@ -97,21 +97,23 @@ def run_diffusion_loop(model, unconditional_context, context, latent, steps, gui def tripy_diffusion(args): run_start_time = time.perf_counter() - if os.path.isdir("engines"): + dtype, torch_dtype = (tp.float16, torch.float16) if args.fp16 else (tp.float32, torch.float32) + + if os.path.isdir(args.engine_dir): print("[I] Loading cached engines from disk...") clip_compiled = tp.Executable.load(os.path.join("engines", "clip_executable.json")) unet_compiled = tp.Executable.load(os.path.join("engines", "unet_executable.json")) vae_compiled = tp.Executable.load(os.path.join("engines", "vae_executable.json")) else: - model = StableDiffusion(StableDiffusionConfig(dtype=tp.float16)) + model = StableDiffusion(StableDiffusionConfig(dtype=dtype)) print("[I] Loading model weights...", flush=True) - load_from_diffusers(model, tp.float16, debug=True) + load_from_diffusers(model, dtype, args.hf_token, debug=True) clip_compiled = compile_clip(model.cond_stage_model.transformer.text_model, verbose=True) - unet_compiled = compile_unet(model, verbose=True) - vae_compiled = compile_vae(model.decode, verbose=True) + unet_compiled = compile_unet(model, dtype, verbose=True) + vae_compiled = compile_vae(model.decode, dtype, verbose=True) - os.mkdir("engines") - print("[I] Saving engines to disk...") + os.mkdir(args.engine_dir) + print(f"[I] Saving engines to {args.engine_dir}...") clip_compiled.save(os.path.join("engines", "clip_executable.json")) unet_compiled.save(os.path.join("engines", "unet_executable.json")) vae_compiled.save(os.path.join("engines", "vae_executable.json")) @@ -135,11 +137,11 @@ def tripy_diffusion(args): # Backbone of diffusion - the UNet if args.seed is not None: torch.manual_seed(args.seed) - torch_latent = torch.randn((1, 4, 64, 64)).to("cuda") + torch_latent = torch.randn((1, 4, 64, 64), dtype=torch_dtype).to("cuda") latent = tp.Tensor(torch_latent) diffusion_run_start = time.perf_counter() - latent = run_diffusion_loop(unet_compiled, unconditional_context, context, latent, args.steps, args.guidance) + latent = run_diffusion_loop(unet_compiled, unconditional_context, context, latent, args.steps, args.guidance, dtype) diffusion_run_end = time.perf_counter() print(f"[I] Finished diffusion denoising. Inference took {diffusion_run_end - diffusion_run_start} seconds.") @@ -173,15 +175,17 @@ def hf_diffusion(args): run_start_time = time.perf_counter() + dtype = torch.float16 if args.fp16 else torch.float32 + model_opts = {'variant': 'fp16', 'torch_dtype': torch.float16} if args.fp16 else {} + # Initialize models - model_id = "CompVis/stable-diffusion-v1-4" #"benjamin-paine/stable-diffusion-v1-5" #"runwayml/stable-diffusion-v1-5" - clip_id = "openai/clip-vit-large-patch14" + model_id = "KiwiXR/stable-diffusion-v1-5" print("[I] Loading models...") - hf_tokenizer = CLIPTokenizer.from_pretrained(clip_id) - hf_encoder = CLIPTextModel.from_pretrained(clip_id).to("cuda") - unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to("cuda") - vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to("cuda") + hf_tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer") + hf_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to("cuda") + unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", use_auth_token=args.hf_token, **model_opts).to("cuda") + vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", use_auth_token=args.hf_token, **model_opts).to("cuda") scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) # Run through CLIP to get context from prompt @@ -192,19 +196,20 @@ def hf_diffusion(args): uncond_input = hf_tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt").to("cuda") text_embeddings = hf_encoder(text_input.input_ids, output_hidden_states=True)[0] uncond_embeddings = hf_encoder(uncond_input.input_ids)[0] - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype) clip_run_end = time.perf_counter() print(f"took {clip_run_end - clip_run_start} seconds.") # Backbone of diffusion - the UNet if args.seed is not None: torch.manual_seed(args.seed) - torch_latent = torch.randn((1, 4, 64, 64)).to("cuda") + torch_latent = torch.randn((1, 4, 64, 64), dtype=dtype).to("cuda") torch_latent *= scheduler.init_noise_sigma scheduler.set_timesteps(args.steps) diffusion_run_start = time.perf_counter() + print(f"[I] Running diffusion for {args.steps} timesteps...") for t in tqdm(scheduler.timesteps): # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. latent_model_input = torch.cat([torch_latent] * 2) @@ -267,7 +272,6 @@ def print_summary(denoising_steps, times): # TODO: Add torch compilation modes -# TODO: Add fp16 support # TODO: Add Timing context def main(): default_prompt = "a horse sized cat eating a bagel" @@ -282,6 +286,8 @@ def main(): parser.add_argument("--seed", type=int, help="Set the random latent seed") parser.add_argument("--guidance", type=float, default=7.5, help="Prompt strength") parser.add_argument('--torch-inference', action='store_true', help="Run inference with PyTorch (eager mode) instead of TensorRT.") + parser.add_argument('--hf-token', type=str, default='', help="HuggingFace API access token for downloading model checkpoints") + parser.add_argument('--engine-dir', type=str, default='engines', help="Output directory for TensorRT engines") args = parser.parse_args() if args.torch_inference: diff --git a/tripy/examples/diffusion/helper.py b/tripy/examples/diffusion/helper.py index 9fdbbc0d1..ae4e956e3 100644 --- a/tripy/examples/diffusion/helper.py +++ b/tripy/examples/diffusion/helper.py @@ -12,7 +12,7 @@ def scaled_dot_product_attention( embedding_dim: Optional[int] = None, attn_mask: Optional[tp.Tensor] = None, is_causal: bool = False, - dtype: tp.dtype = tp.float16 + dtype: tp.dtype = tp.float32 ) -> tp.Tensor: """ Computes scaled dot-product attention. diff --git a/tripy/examples/diffusion/model.py b/tripy/examples/diffusion/model.py index 85984f6ee..9fad013bb 100644 --- a/tripy/examples/diffusion/model.py +++ b/tripy/examples/diffusion/model.py @@ -33,7 +33,7 @@ @dataclass class StableDiffusionConfig: - dtype: tp.dtype = tp.float16 + dtype: tp.dtype = tp.float32 clip_config: Optional[CLIPConfig] = field(default=None, init=False) unet_config: Optional[UNetConfig] = field(default=None, init=False) vae_config: Optional[VAEConfig] = field(default=None, init=False) @@ -44,11 +44,11 @@ def __post_init__(self): self.vae_config = VAEConfig(dtype=self.dtype) # equivalent to LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) -def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000): +def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000, dtype=tp.float32): betas = np.linspace(beta_start**0.5, beta_end**0.5, n_training_steps, dtype=np.float32) ** 2 alphas = 1.0 - betas alphas_cumprod = np.cumprod(alphas, axis=0) - return tp.Tensor(alphas_cumprod) + return tp.cast(tp.Tensor(alphas_cumprod), dtype) class StableDiffusion(tp.Module): diff --git a/tripy/examples/diffusion/unet_model.py b/tripy/examples/diffusion/unet_model.py index fb91f2249..a18bd3d71 100644 --- a/tripy/examples/diffusion/unet_model.py +++ b/tripy/examples/diffusion/unet_model.py @@ -18,6 +18,7 @@ import math from typing import List, Tuple +import torch import tripy as tp from dataclasses import dataclass @@ -33,28 +34,30 @@ class UNetConfig: num_heads: int = 8 context_dim: int = 768 emb_channels: int = 1280 - dtype: tp.dtype = tp.float16 + dtype: tp.dtype = tp.float32 # Used for UNet, not to be confused with ResnetBlock, called ResnetBlock2D in HF diffusers class ResBlock(tp.Module): def __init__(self, config: UNetConfig, channels, emb_channels, out_channels): - self.norm1 = tp.GroupNorm(32, channels, dtype=config.dtype) + self.norm1 = tp.GroupNorm(32, channels, dtype=tp.float32) self.conv1 = tp.Conv(channels, out_channels, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype) self.time_emb_proj = tp.Linear(emb_channels, out_channels, dtype=config.dtype) - self.norm2 = tp.GroupNorm(32, out_channels, dtype=config.dtype) + self.norm2 = tp.GroupNorm(32, out_channels, dtype=tp.float32) self.conv2 = tp.Conv(out_channels, out_channels, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype) self.nonlinearity = tp.silu self.conv_shortcut = tp.Conv(channels, out_channels, (1, 1), dtype=config.dtype) if channels != out_channels else lambda x: x def __call__(self, x, emb): - h = self.conv1(self.nonlinearity(self.norm1(x))) + h = tp.cast(self.norm1(tp.cast(x, self.norm1.dtype)), x.dtype) + h = self.conv1(self.nonlinearity(h)) emb_out = self.time_emb_proj(self.nonlinearity(emb)) target_shape = emb_out.shape + (1, 1) # TODO: #228: WAR to prevent computing output rank in infer_rank for reshape target_shape.trace_tensor.shape = (emb_out.rank + 2,) h = h + tp.reshape(emb_out, target_shape) - h = self.conv2(self.nonlinearity(self.norm2(h))) + h = tp.cast(self.norm2(tp.cast(h, self.norm2.dtype)), h.dtype) + h = self.conv2(self.nonlinearity(h)) ret = self.conv_shortcut(x) + h return ret @@ -67,6 +70,7 @@ def __init__(self, config: UNetConfig, query_dim, context_dim, n_heads, d_head): self.num_heads = n_heads self.head_size = d_head self.to_out = [tp.Linear(n_heads * d_head, query_dim, dtype=config.dtype)] + self.dtype = config.dtype def __call__(self, x, context=None): context = x if context is None else context @@ -74,7 +78,7 @@ def __call__(self, x, context=None): q, k, v = [ tp.transpose(tp.reshape(y, (x.shape[0], -1, self.num_heads, self.head_size)), 1, 2) for y in (q, k, v) ] - attention = tp.transpose(scaled_dot_product_attention(q, k, v, embedding_dim=self.head_size), 1, 2) + attention = tp.transpose(scaled_dot_product_attention(q, k, v, embedding_dim=self.head_size, dtype=self.dtype), 1, 2) h_ = tp.reshape(attention, (x.shape[0], -1, self.num_heads * self.head_size)) out = sequential(h_, self.to_out) return out @@ -116,20 +120,20 @@ def __init__(self, config, dim, context_dim, n_heads, d_head): self.attn1 = CrossAttention(config, dim, dim, n_heads, d_head) self.ff = FeedForward(config, dim) self.attn2 = CrossAttention(config, dim, context_dim, n_heads, d_head) - self.norm1 = tp.LayerNorm(dim, dtype=config.dtype) - self.norm2 = tp.LayerNorm(dim, dtype=config.dtype) - self.norm3 = tp.LayerNorm(dim, dtype=config.dtype) + self.norm1 = tp.LayerNorm(dim, dtype=tp.float32) + self.norm2 = tp.LayerNorm(dim, dtype=tp.float32) + self.norm3 = tp.LayerNorm(dim, dtype=tp.float32) def __call__(self, x, context=None): - x = self.attn1(self.norm1(x)) + x - x = self.attn2(self.norm2(x), context=context) + x - x = self.ff(self.norm3(x)) + x + x = self.attn1(tp.cast(self.norm1(tp.cast(x, self.norm1.dtype)), x.dtype)) + x + x = self.attn2(tp.cast(self.norm2(tp.cast(x, self.norm2.dtype)), x.dtype), context=context) + x + x = self.ff(tp.cast(self.norm3(tp.cast(x, self.norm3.dtype)), x.dtype)) + x return x class SpatialTransformer(tp.Module): # Transformer2dModel in HF diffusers def __init__(self, config: UNetConfig, channels, context_dim, n_heads, d_head): - self.norm = tp.GroupNorm(32, channels, dtype=config.dtype) + self.norm = tp.GroupNorm(32, channels, dtype=tp.float32) assert channels == n_heads * d_head self.proj_in = tp.Conv(channels, n_heads * d_head, (1, 1), dtype=config.dtype) self.transformer_blocks = [BasicTransformerBlock(config, channels, context_dim, n_heads, d_head)] @@ -138,7 +142,7 @@ def __init__(self, config: UNetConfig, channels, context_dim, n_heads, d_head): def __call__(self, x, context=None): b, c, h, w = x.shape x_in = x - x = self.norm(x) + x = tp.cast(self.norm(tp.cast(x, self.norm.dtype)), x.dtype) x = self.proj_in(x) x = tp.permute(tp.reshape(x, (b, c, h * w)), (0, 2, 1)) for block in self.transformer_blocks: @@ -272,7 +276,7 @@ def __init__(self, config: UNetConfig): CrossAttnUpBlock2D(config, up_channels[2:5], down_channels[2]), CrossAttnUpBlock2D(config, up_channels[4:7], down_channels[1], use_upsampler=False), ] - self.conv_norm_out = tp.GroupNorm(32, config.model_channels, dtype=config.dtype) + self.conv_norm_out = tp.GroupNorm(32, config.model_channels, dtype=tp.float32) self.conv_act = tp.silu self.conv_out = tp.Conv(config.model_channels, config.io_channels, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype) @@ -280,7 +284,6 @@ def __call__(self, x, timesteps=None, context=None): # TODO: real time embedding t_emb = timestep_embedding(timesteps, self.config.model_channels, self.config.dtype) emb = self.time_embedding(t_emb) - x = self.conv_in(x) saved_inputs = [x] @@ -301,6 +304,7 @@ def __call__(self, x, timesteps=None, context=None): else: x = block(x, emb, context, partial_inputs) - act = self.conv_out(self.conv_act(self.conv_norm_out(x))) + act = tp.cast(self.conv_norm_out(tp.cast(x, self.conv_norm_out.dtype)), x.dtype) + act = self.conv_out(self.conv_act(act)) return act diff --git a/tripy/examples/diffusion/vae_model.py b/tripy/examples/diffusion/vae_model.py index cc19103f9..111bdb333 100644 --- a/tripy/examples/diffusion/vae_model.py +++ b/tripy/examples/diffusion/vae_model.py @@ -29,28 +29,29 @@ class VAEConfig: model_channel: int = 128 channel_mult_encode: Tuple[int] = (1, 1, 2, 4, 4) channel_mult_decode: Tuple[int] = (4, 4, 4, 2, 1) - dtype: tp.dtype = tp.float16 + dtype: tp.dtype = tp.float32 class AttnBlock(tp.Module): def __init__(self, config: VAEConfig, in_channels): - self.group_norm = tp.GroupNorm(32, in_channels, dtype=config.dtype) + self.group_norm = tp.GroupNorm(32, in_channels, dtype=tp.float32) self.to_q = tp.Linear(in_channels, in_channels, dtype=config.dtype) self.to_k = tp.Linear(in_channels, in_channels, dtype=config.dtype) self.to_v = tp.Linear(in_channels, in_channels, dtype=config.dtype) self.to_out = [tp.Linear(in_channels, in_channels, dtype=config.dtype)] self.in_channels = in_channels + self.dtype = config.dtype # adapted from AttnBlock in ldm repo def __call__(self, x): - h_ = self.group_norm(x) + h_ = tp.cast(self.group_norm(tp.cast(x, self.group_norm.dtype)), x.dtype) b, c, h, w = h_.shape h_flat = tp.transpose(tp.reshape(h_, (b, c, h * w)), 1, 2) q, k, v = self.to_q(h_flat), self.to_k(h_flat), self.to_v(h_flat) # compute attention - h_ = scaled_dot_product_attention(q, k, v, embedding_dim=self.in_channels) + h_ = scaled_dot_product_attention(q, k, v, embedding_dim=self.in_channels, dtype=self.dtype) out = tp.reshape( tp.transpose(self.to_out[0](h_), 1, 2), (b, c, h, w), @@ -60,16 +61,16 @@ def __call__(self, x): # Not to be confused with ResBlock. Called ResnetBlock2D in HF diffusers class ResnetBlock(tp.Module): def __init__(self, config: VAEConfig, in_channels, out_channels=None): - self.norm1 = tp.GroupNorm(32, in_channels, dtype=config.dtype) + self.norm1 = tp.GroupNorm(32, in_channels, dtype=tp.float32) self.conv1 = tp.Conv(in_channels, out_channels, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype) - self.norm2 = tp.GroupNorm(32, out_channels, dtype=config.dtype) + self.norm2 = tp.GroupNorm(32, out_channels, dtype=tp.float32) self.conv2 = tp.Conv(out_channels, out_channels, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype) self.nonlinearity = tp.silu self.conv_shortcut = tp.Conv(in_channels, out_channels, (1, 1), dtype=config.dtype) if in_channels != out_channels else lambda x: x def __call__(self, x): - h = self.conv1(self.nonlinearity(self.norm1(x))) - h = self.conv2(self.nonlinearity(self.norm2(h))) + h = self.conv1(self.nonlinearity(tp.cast(self.norm1(tp.cast(x, self.norm1.dtype)), x.dtype))) + h = self.conv2(self.nonlinearity(tp.cast(self.norm2(tp.cast(h, self.norm2.dtype)), h.dtype))) return self.conv_shortcut(x) + h class Downsample(tp.Module): @@ -122,7 +123,7 @@ def __init__(self, config: VAEConfig): self.conv_in = tp.Conv(config.latent_channels, config.model_channel * config.channel_mult_decode[0], (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype) self.up_blocks = [UpDecoderBlock2D(config, up_channels[i], up_channels[i+1], use_upsampler=upsamplers[i]) for i in range(num_resolutions)] self.mid_block = Mid(config, up_channels[0]) - self.conv_norm_out = tp.GroupNorm(32, config.model_channel, dtype=config.dtype) + self.conv_norm_out = tp.GroupNorm(32, config.model_channel, dtype=tp.float32) self.conv_act = tp.silu self.conv_out = tp.Conv(config.model_channel, config.io_channels, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype) @@ -132,7 +133,7 @@ def __call__(self, x): for up_block in self.up_blocks: x = up_block(x) - return self.conv_out(self.conv_act(self.conv_norm_out(x))) + return self.conv_out(self.conv_act(tp.cast(self.conv_norm_out(tp.cast(x, self.conv_norm_out.dtype)), x.dtype))) class DownEncoderBlock2D(tp.Module): def __init__(self, config: VAEConfig, start_channels, channels, use_downsampler=True): @@ -156,7 +157,7 @@ def __init__(self, config: VAEConfig): self.conv_in = tp.Conv(config.io_channels, config.model_channel, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype) self.down_blocks = [DownEncoderBlock2D(config, down_channels[i], down_channels[i+1], use_downsampler=downsamplers[i]) for i in range(num_resolutions)] self.mid_block = Mid(config, down_channels[-1]) - self.conv_norm_out = tp.GroupNorm(32, down_channels[-1], dtype=config.dtype) + self.conv_norm_out = tp.GroupNorm(32, down_channels[-1], dtype=tp.float32) self.conv_act = tp.silu self.conv_out = tp.Conv(down_channels[-1], 8, (3, 3), padding=((1, 1), (1, 1)), dtype=config.dtype) @@ -165,7 +166,7 @@ def __call__(self, x): for i in range(len(self.down_blocks)): x = self.down_blocks[i](x) x = self.mid_block(x) - return self.conv_out(self.conv_act(self.conv_norm_out(x))) + return self.conv_out(self.conv_act(tp.cast(self.conv_norm_out(tp.cast(x, self.conv_norm_out.dtype)), x.dtype))) class AutoencoderKL(tp.Module): diff --git a/tripy/examples/diffusion/weight_loader.py b/tripy/examples/diffusion/weight_loader.py index 14d98bedd..3fa3b0557 100644 --- a/tripy/examples/diffusion/weight_loader.py +++ b/tripy/examples/diffusion/weight_loader.py @@ -20,21 +20,16 @@ def load_weights_from_hf(model, hf_model, dtype, debug=False): torch_dtype = getattr(torch, dtype.name) for key in hf_keys: weight = hf_state_dict[key] - # print(weight.dtype) - # if "ln" in key or "gn" in key or "norm" in key: - # print(f"{key}: {weight.dtype}") - # if "norm" not in key: - # weight = weight.to(torch_dtype) - # print(f"{key}: {weight.dtype}") + if "norm" not in key: + weight = weight.to(torch_dtype) param = tp.Parameter(weight) tripy_state_dict[key.removeprefix("text_model.")] = param model.load_from_state_dict(tripy_state_dict) -def load_from_diffusers(model, dtype, debug=False): - model_id = "CompVis/stable-diffusion-v1-4" #"benjamin-paine/stable-diffusion-v1-5" #"runwayml/stable-diffusion-v1-5" - model_opts = {'variant': 'fp16', 'torch_dtype': torch.float16} if dtype == tp.float16 else {} - pipe = StableDiffusionPipeline.from_pretrained(model_id, **model_opts) +def load_from_diffusers(model, dtype, hf_token, debug=False): + model_id = "KiwiXR/stable-diffusion-v1-5" + pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=hf_token) load_weights_from_hf(model.cond_stage_model.transformer.text_model, pipe.text_encoder, dtype, debug=debug) load_weights_from_hf(model.model.diffusion_model, pipe.unet, dtype, debug=debug)