Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MLX_UNET.py #53

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 24 additions & 23 deletions ml_mdm/clis/generate_sample.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All rights reserved.
import argparse
import logging
import os
import shlex
import time
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Tuple

import gradio as gr
import simple_parsing
Expand Down Expand Up @@ -36,7 +37,7 @@
)


def dividable(n):
def dividable(n: int) -> Tuple[int, int]:
for i in range(int(np.sqrt(n)), 0, -1):
if n % i == 0:
break
Expand All @@ -51,7 +52,7 @@ def generate_lm_outputs(device, sample, tokenizer, language_model, args):
return sample


def setup_models(args, device):
def setup_models(args: argparse.Namespace, device: torch.device):
input_channels = 3

# load the language model
Expand Down Expand Up @@ -112,32 +113,32 @@ def stop_run():
)


def get_model_type(config_file):
def get_model_type(config_file: str):
with open(config_file, "r") as f:
d = yaml.safe_load(f)
return d.get("model", d.get("vision_model", "unet"))


def generate(
config_file="cc12m_64x64.yaml",
ckpt_name="vis_model_64x64.pth",
prompt="a chair",
input_template="",
negative_prompt="",
negative_template="",
batch_size=20,
guidance_scale=7.5,
threshold_function="clip",
num_inference_steps=250,
eta=0,
save_diffusion_path=False,
show_diffusion_path=False,
show_xt=False,
reader_config="",
seed=10,
comment="",
override_args="",
output_inner=False,
config_file: str = "cc12m_64x64.yaml",
ckpt_name: str = "vis_model_64x64.pth",
prompt: str = "a chair",
input_template: str = "",
negative_prompt: str = "",
negative_template: str = "",
batch_size: int = 20,
guidance_scale: float = 7.5,
threshold_function: str = "clip",
num_inference_steps: int = 250,
eta: int = 0,
save_diffusion_path: bool = False,
show_diffusion_path: bool = False,
show_xt: bool = False,
reader_config: str = "",
seed: int = 10,
comment: str = "",
override_args: str = "",
output_inner: bool = False,
):
np.random.seed(seed)
torch.random.manual_seed(seed)
Expand Down
55 changes: 55 additions & 0 deletions ml_mdm/models/mlx_model_ema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All rights reserved.
import logging
from copy import deepcopy

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map

from ml_mdm.utils import fix_old_checkpoints


class ModelEma(nn.Module):
    """Exponential moving average (EMA) of a model's parameters, in MLX.

    Keeps an independent copy of ``model`` whose weights are blended as
    ``ema = decay * ema + (1 - decay) * model`` on every :meth:`update`
    call. For the first ``warmup_steps`` calls the effective decay is 0,
    so the EMA copy tracks the raw model exactly during warmup.

    Note: the original draft of this class called PyTorch-only APIs
    (``mx.no_grad``, ``state_dict``, ``.mul_``/``.add_``, ``map_location``)
    that do not exist in MLX; this version uses the MLX equivalents
    (functional ``tree_map`` update, ``mx.stop_gradient``, ``mx.savez``,
    ``Module.load_weights``).
    """

    def __init__(self, model, decay=0.9999, warmup_steps=0, device=None):
        super().__init__()
        # Deep-copy so the EMA weights evolve independently of the live model.
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        # MLX uses unified memory, so no cross-device copies are required;
        # the argument is kept only for interface compatibility with callers.
        self.device = device
        self.warmup_steps = warmup_steps
        self.counter = 0

    def update(self, model):
        """Blend ``model``'s current parameters into the EMA copy."""
        # Effective decay is 0 until warmup finishes, i.e. the EMA copy
        # simply mirrors the raw weights during warmup.
        decay = (self.counter >= self.warmup_steps) * self.decay
        self.counter += 1
        # MLX arrays have no in-place mul_/add_; compute the new parameter
        # tree functionally and write it back with Module.update().
        # mx.stop_gradient replaces torch's .detach() / no_grad().
        new_params = tree_map(
            lambda ema_v, model_v: decay * ema_v
            + (1.0 - decay) * mx.stop_gradient(model_v),
            self.module.parameters(),
            model.parameters(),
        )
        self.module.update(new_params)

    def save(self, fname, other_items=None):
        """Save the EMA weights (plus optional extra arrays) to ``fname``."""
        logging.info(f"Saving EMA model file: {fname}")
        # Flatten the nested parameter tree into "a.b.c" -> array entries,
        # the layout Module.load_weights expects back.
        weights = dict(tree_flatten(self.module.parameters()))
        if other_items is not None:
            # NOTE(review): mx.savez stores arrays only — extra items must
            # be mx.array-convertible (e.g. scalar step counters).
            for k, v in other_items.items():
                weights[k] = v
        mx.savez(fname, **weights)

    def load(self, fname):
        """Load EMA weights from ``fname``, tolerating unknown keys."""
        logging.info(f"Loading EMA model file: {fname}")
        fix_old_checkpoints.mimic_old_modules()
        # strict=False mirrors the old filtered load: keys present in the
        # checkpoint but absent from the module (and vice versa) are ignored.
        self.module.load_weights(fname, strict=False)
Loading