Add mps-cpu Support #71

Open · wants to merge 14 commits into base: main
Binary file added .DS_Store
Binary file not shown.
64 changes: 64 additions & 0 deletions README.md
@@ -50,9 +50,16 @@ Note that the way we connect layers is computationally efficient. The original SD

First create a new conda environment

#### CUDA, CPU

conda env create -f environment.yaml
conda activate control

#### MPS

conda env create -f environment-mps.yaml
conda activate control

All models and detectors can be downloaded from [our Hugging Face page](https://huggingface.co/lllyasviel/ControlNet). Make sure that SD models are put in "ControlNet/models" and detectors are put in "ControlNet/annotator/ckpts". Make sure that you download all necessary pretrained weights and detector models from that Hugging Face page, including HED edge detection model, Midas depth estimation model, Openpose, and so on.
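
If you prefer to script the downloads, below is a minimal sketch using `huggingface_hub`. The two file names are examples only (adjust them to the checkpoints you actually need), and it assumes you run it from the ControlNet repo root.

```python
# Sketch only: fetch one SD model and one detector checkpoint from the Hugging Face
# repo into the folders the scripts expect. File names below are examples.
import shutil
from huggingface_hub import hf_hub_download

def fetch(filename: str, target: str) -> None:
    cached = hf_hub_download(repo_id="lllyasviel/ControlNet", filename=filename)
    shutil.copy(cached, target)

fetch("models/control_sd15_canny.pth", "./models/control_sd15_canny.pth")
fetch("annotator/ckpts/network-bsds500.pth", "./annotator/ckpts/network-bsds500.pth")
```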

We provide 9 Gradio apps with these models.
@@ -63,8 +70,14 @@ All test images can be found at the folder "test_imgs".

Stable Diffusion 1.5 + ControlNet (using simple Canny edge detection)

#### CUDA, CPU

python gradio_canny2image.py

#### MPS

PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_canny2image.py
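
`PYTORCH_ENABLE_MPS_FALLBACK=1` tells PyTorch to run any operator that has no MPS kernel on the CPU instead of raising an error. If you would rather set it from Python than from the shell, a minimal sketch (the variable should be set before `torch` is imported):

```python
# Sketch: equivalent of prefixing the command with PYTORCH_ENABLE_MPS_FALLBACK=1.
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
print(torch.backends.mps.is_available())  # True on Apple Silicon with a recent PyTorch
```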

The Gradio app also allows you to change the Canny edge thresholds. Just try it for more details.
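
The two sliders correspond to the low and high thresholds of OpenCV's Canny detector, which is what the repo's `CannyDetector` wraps. A quick way to get a feel for them outside the app (the file name is an assumption):

```python
# Sketch: lower thresholds keep more (noisier) edges; higher thresholds keep fewer, stronger ones.
import cv2

img = cv2.imread("test_imgs/bird.png", cv2.IMREAD_GRAYSCALE)  # any test image works
cv2.imwrite("edges_loose.png", cv2.Canny(img, 50, 100))
cv2.imwrite("edges_strict.png", cv2.Canny(img, 200, 255))
```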

Prompt: "bird"
@@ -77,8 +90,14 @@ Prompt: "cute dog"

Stable Diffusion 1.5 + ControlNet (using simple M-LSD straight line detection)

#### CUDA, CPU

python gradio_hough2image.py

#### MPS

PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_hough2image.py

The Gradio app also allows you to change the M-LSD thresholds. Just try it for more details.

Prompt: "room"
@@ -91,8 +110,14 @@ Prompt: "building"

Stable Diffusion 1.5 + ControlNet (using soft HED Boundary)

#### CUDA, CPU

python gradio_hed2image.py

#### MPS

PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_hed2image.py

The soft HED Boundary will preserve many details in input images, making this app suitable for recoloring and stylizing. Just try it for more details.

Prompt: "oil painting of handsome old man, masterpiece"
@@ -105,8 +130,14 @@ Prompt: "Cyberpunk robot"

Stable Diffusion 1.5 + ControlNet (using Scribbles)

#### CUDA, CPU

python gradio_scribble2image.py

#### MPS

PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_scribble2image.py

Note that the UI is based on Gradio, and Gradio is somewhat difficult to customize. Right now you need to draw scribbles outside the UI (using your favorite drawing software, for example, MS Paint) and then import the scribble image to Gradio.

Prompt: "turtle"
@@ -119,8 +150,14 @@ Prompt: "hot air balloon"

We actually provide an interactive interface

#### CUDA, CPU

python gradio_scribble2image_interactive.py

#### MPS

PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_scribble2image_interactive.py

~~However, because Gradio is very [buggy](https://github.com/gradio-app/gradio/issues/3166) and difficult to customize, right now users need to first set the canvas width and height and then click "Open drawing canvas" to get a drawing area. Please do not upload an image to that drawing canvas. Also, the drawing area is very small; it should be bigger. But I failed to find out how to make it larger. Again, Gradio is really buggy.~~ (Now fixed, will update asap)

The below dog sketch is drawn by me. Perhaps we should draw a better dog for showcase.
@@ -132,8 +169,14 @@ Prompt: "dog in a room"

Stable Diffusion 1.5 + ControlNet (using fake scribbles)

#### CUDA, CPU

python gradio_fake_scribble2image.py

#### MPS

PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_fake_scribble2image.py

Sometimes we are lazy, and we do not want to draw scribbles. This script uses exactly the same scribble-based model but uses a simple algorithm to synthesize scribbles from input images.
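
Roughly, the synthesis boils down to binarizing a soft HED edge map so it looks like hand-drawn strokes. A simplified sketch of the idea (the exact steps and constants in `gradio_fake_scribble2image.py` may differ):

```python
import cv2
import numpy as np

def fake_scribble(hed_map: np.ndarray, blur_sigma: float = 3.0, thresh: int = 4) -> np.ndarray:
    """hed_map: HxW uint8 soft edge map from the HED detector."""
    soft = cv2.GaussianBlur(hed_map, (0, 0), blur_sigma)  # smooth/thicken the soft edges
    scribble = np.zeros_like(soft)
    scribble[soft > thresh] = 255                          # hard-threshold into scribble strokes
    return scribble
```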

Prompt: "bag"
@@ -146,8 +189,12 @@ Prompt: "shose" (Note that "shose" is a typo; it should be "shoes". But it still

Stable Diffusion 1.5 + ControlNet (using human pose)

#### CUDA, CPU

python gradio_pose2image.py

#### MPS

PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_pose2image.py

Apparently, this model deserves a better UI to directly manipulate the pose skeleton. However, again, Gradio is somewhat difficult to customize. Right now you need to input an image, and then Openpose will detect the pose for you.
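
For reference, the detection step looks roughly like this. This is a sketch only: the file name is made up, and `OpenposeDetector` may take a `device` argument after this PR, mirroring the other annotators.

```python
# Sketch of the pose-detection step used before sampling.
import cv2
from annotator.openpose import OpenposeDetector
from annotator.util import resize_image, HWC3

apply_openpose = OpenposeDetector()
img = HWC3(cv2.imread("test_imgs/pose1.png"))              # file name is an assumption
detected_map, _ = apply_openpose(resize_image(img, 512))   # rendered pose-skeleton image
cv2.imwrite("pose_map.png", HWC3(detected_map))
```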

Prompt: "Chief in the kitchen"
@@ -160,8 +207,13 @@ Prompt: "An astronaut on the moon"

Stable Diffusion 1.5 + ControlNet (using semantic segmentation)

#### CUDA, CPU

python gradio_seg2image.py

#### MPS

Not supported (reason: `aten::_slow_conv2d_forward` is currently not supported by MPS).
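
If you still want to run this app on Apple Silicon, one untested option is to force the CPU path in the script's device selection, which is slow but avoids the missing kernel:

```python
# Hypothetical tweak to gradio_seg2image.py's device selection: never return 'mps'.
def get_device():
    import torch
    if torch.cuda.is_available():
        return 'cuda'
    # aten::_slow_conv2d_forward has no MPS kernel, so fall back to the CPU here.
    return 'cpu'
```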

This model uses ADE20K's segmentation protocol. Again, this model deserves a better UI to directly draw the segmentations. However, again, Gradio is somewhat difficult to customize. Right now you need to input an image, and then a model called Uniformer will detect the segmentations for you. Just try it for more details.

Prompt: "House"
@@ -174,8 +226,14 @@ Prompt: "River"

Stable Diffusion 1.5 + ControlNet (using depth map)

#### CUDA, CPU

python gradio_depth2image.py

#### MPS

PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_depth2image.py

Great! Now SD 1.5 also has depth control. FINALLY. So many possibilities (considering SD 1.5 has many more community models than SD2).

Note that, unlike Stability's model, this ControlNet receives the full 512×512 depth map rather than a 64×64 depth map (Stability's SD2 depth model uses 64×64 depth maps). This means that the ControlNet will preserve more details in the depth map.
@@ -189,8 +247,14 @@ Prompt: "Stormtrooper's lecture"

Stable Diffusion 1.5 + ControlNet (using normal map)

#### CUDA, CPU

python gradio_normal2image.py

#### MPS

PYTORCH_ENABLE_MPS_FALLBACK=1 python gradio_normal2image.py

This model uses normal maps. Right now in the app, the normal map is computed from the Midas depth map and a user threshold (which determines how much of the image is treated as background with an identity normal facing the viewer; tune the "Normal background threshold" in the Gradio app to get a feel for it).
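
Conceptually, the conversion takes image-space gradients of the depth map and applies the background cut-off. A rough sketch of the idea, not the exact code the app uses:

```python
import numpy as np

def depth_to_normal(depth: np.ndarray, bg_threshold: float = 0.4) -> np.ndarray:
    """depth: HxW array in [0, 1], larger = closer. Returns an HxWx3 uint8 normal map."""
    dz_dx = np.gradient(depth, axis=1)
    dz_dy = np.gradient(depth, axis=0)
    normal = np.stack([-dz_dx, -dz_dy, np.ones_like(depth)], axis=-1)
    normal /= np.linalg.norm(normal, axis=-1, keepdims=True)
    normal[depth < bg_threshold] = np.array([0.0, 0.0, 1.0])  # background faces the viewer
    return ((normal * 0.5 + 0.5) * 255.0).astype(np.uint8)
```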

Prompt: "Cute toy"
9 changes: 5 additions & 4 deletions annotator/hed/__init__.py
@@ -93,20 +93,21 @@ def forward(self, tenInput):
return self.netCombine(torch.cat([ tenScoreOne, tenScoreTwo, tenScoreThr, tenScoreFou, tenScoreFiv ], 1))


class HEDdetector:
def __init__(self):
class HEDdetector():
def __init__(self, device):
self.device = device
remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/network-bsds500.pth"
modelpath = os.path.join(annotator_ckpts_path, "network-bsds500.pth")
if not os.path.exists(modelpath):
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
self.netNetwork = Network(modelpath).cuda().eval()
self.netNetwork = Network(modelpath).to(device).eval()

def __call__(self, input_image):
assert input_image.ndim == 3
input_image = input_image[:, :, ::-1].copy()
with torch.no_grad():
image_hed = torch.from_numpy(input_image).float().cuda()
image_hed = torch.from_numpy(input_image).float().to(self.device)
image_hed = image_hed / 255.0
image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
edge = self.netNetwork(image_hed)[0]
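
With this change the annotators no longer assume CUDA. Construction then looks like this (a sketch that mirrors the device selection added to the Gradio scripts below):

```python
# Sketch: pick whichever backend is available and hand it to the detector.
import torch
from annotator.hed import HEDdetector

if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

apply_hed = HEDdetector(device)
# edge_map = apply_hed(input_image)  # input_image: HxWx3 uint8 array
```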
7 changes: 4 additions & 3 deletions annotator/midas/__init__.py
@@ -7,14 +7,15 @@


class MidasDetector:
def __init__(self):
self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
def __init__(self, device):
self.device = device
self.model = MiDaSInference(model_type="dpt_hybrid").to(device)

def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
assert input_image.ndim == 3
image_depth = input_image
with torch.no_grad():
image_depth = torch.from_numpy(image_depth).float().cuda()
image_depth = torch.from_numpy(image_depth).float().to(self.device)
image_depth = image_depth / 127.5 - 1.0
image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
depth = self.model(image_depth)[0]
4 changes: 2 additions & 2 deletions annotator/mlsd/__init__.py
@@ -15,14 +15,14 @@


class MLSDdetector:
def __init__(self):
def __init__(self, device):
model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth")
if not os.path.exists(model_path):
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
model = MobileV2_MLSD_Large()
model.load_state_dict(torch.load(model_path), strict=True)
self.model = model.cuda().eval()
self.model = model.to(device).eval()

def __call__(self, input_image, thr_v, thr_d):
assert input_image.ndim == 3
10 changes: 7 additions & 3 deletions cldm/ddim_hacked.py
@@ -8,16 +8,20 @@


class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
def __init__(self, model, device, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device

def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
if attr.device != torch.device(self.device):
if str(self.device) == 'mps':
attr = attr.to(self.device, torch.float32)
else:
attr = attr.to(self.device)
setattr(self, name, attr)

def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
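
The extra float32 cast for MPS is there because the MPS backend does not support float64, and some of the DDIM schedule buffers arrive as double precision. A quick sketch of the limitation (assumes a Mac with an MPS-capable PyTorch build):

```python
# Sketch: float64 tensors cannot be moved to the MPS device, which is why the
# buffers are cast to float32 before being registered there.
import torch

x = torch.zeros(4, dtype=torch.float64)
try:
    x.to('mps')
except (TypeError, RuntimeError) as e:
    print(e)  # e.g. "Cannot convert a MPS Tensor to float64 dtype ..."

print(x.to('mps', torch.float32).dtype)  # torch.float32 works fine
```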
4 changes: 3 additions & 1 deletion cldm/model.py
@@ -9,8 +9,10 @@ def get_state_dict(d):
return d.get('state_dict', d)


def load_state_dict(ckpt_path, location='cpu'):
def load_state_dict(ckpt_path, location):
_, extension = os.path.splitext(ckpt_path)
if str(location) == "mps":
location = "cpu"
if extension.lower() == ".safetensors":
import safetensors.torch
state_dict = safetensors.torch.load_file(ckpt_path, device=location)
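
The `mps` to `cpu` remap means checkpoints are always deserialized on the CPU, presumably because loading a checkpoint directly onto the MPS device is not reliable with the PyTorch 1.12 pin in environment-mps.yaml; the assembled model is moved to the device afterwards. A sketch of the resulting flow, matching the Gradio scripts below:

```python
# Sketch: weights are loaded on the CPU (load_state_dict remaps 'mps' to 'cpu'
# internally) and the whole model is moved to the chosen device afterwards.
from cldm.model import create_model, load_state_dict

device = 'mps'  # or 'cuda' / 'cpu'
model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location=device))
model = model.to(device)
```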
34 changes: 34 additions & 0 deletions environment-mps.yaml
@@ -0,0 +1,34 @@
name: control
channels:
- pytorch
- defaults
dependencies:
- python=3.8
- pip
- pytorch=1.12.1
- torchvision=0.13.1
- numpy=1.23.1
- pip:
- gradio==3.16.2
- albumentations==1.3.0
- opencv-contrib-python
- imageio==2.9.0
- imageio-ffmpeg==0.4.2
- pytorch-lightning==1.5.0
- omegaconf==2.1.1
- test-tube>=0.7.5
- streamlit==1.12.1
- einops==0.3.0
- transformers==4.19.2
- webdataset==0.2.5
- kornia==0.6
- open_clip_torch==2.0.2
- invisible-watermark>=0.1.5
- streamlit-drawable-canvas==0.8.0
- torchmetrics==0.6.0
- timm==0.6.12
- addict==2.4.0
- yapf==0.32.0
- prettytable==3.6.0
- safetensors==0.2.7
- basicsr==1.4.2
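
After creating the environment, a quick sanity check (a sketch) that the pinned PyTorch build can actually see the Apple GPU:

```python
# Sketch: verify that the pinned PyTorch build sees the MPS backend.
import torch

print(torch.__version__)                  # expected: 1.12.1
print(torch.backends.mps.is_available())  # True on Apple Silicon
print(torch.backends.mps.is_built())      # True if this build includes MPS support
```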
18 changes: 14 additions & 4 deletions gradio_canny2image.py
@@ -15,12 +15,22 @@
from cldm.ddim_hacked import DDIMSampler


def get_device():
if(torch.cuda.is_available()):
return 'cuda'
elif(torch.backends.mps.is_available()):
return 'mps'
else:
return 'cpu'


apply_canny = CannyDetector()

device = get_device()
model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cuda'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)
model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location=device))
model = model.to(device)
ddim_sampler = DDIMSampler(model, device)


def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold):
@@ -31,7 +41,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
detected_map = apply_canny(img, low_threshold, high_threshold)
detected_map = HWC3(detected_map)

control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0
control = torch.stack([control for _ in range(num_samples)], dim=0)
control = einops.rearrange(control, 'b h w c -> b c h w').clone()

20 changes: 15 additions & 5 deletions gradio_depth2image.py
@@ -15,12 +15,22 @@
from cldm.ddim_hacked import DDIMSampler


apply_midas = MidasDetector()
def get_device():
if(torch.cuda.is_available()):
return 'cuda'
elif(torch.backends.mps.is_available()):
return 'mps'
else:
return 'cpu'


device = get_device()
apply_midas = MidasDetector(device)

model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict('./models/control_sd15_depth.pth', location='cuda'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)
model.load_state_dict(load_state_dict('./models/control_sd15_depth.pth', location=device))
model = model.to(device)
ddim_sampler = DDIMSampler(model, device)


def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
@@ -33,7 +43,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti

detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0
control = torch.stack([control for _ in range(num_samples)], dim=0)
control = einops.rearrange(control, 'b h w c -> b c h w').clone()

Expand Down
Loading