Commit

remove irrelevant code
lucidrains committed Jan 19, 2021
1 parent 5302a54 commit 816a04b
Showing 2 changed files with 7 additions and 39 deletions.
42 changes: 5 additions & 37 deletions deep_daze/clip.py
@@ -1,7 +1,7 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
-from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+from torchvision.transforms import Normalize
 
 import hashlib
 import os
@@ -160,20 +160,13 @@ def _download(url, root = os.path.expanduser("~/.cache/clip")):
 
     return download_target
 
-normalize_image = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-
-def load(device = ("cuda" if torch.cuda.is_available() else "cpu")):
+def load():
+    device = 'cuda'
     model_path = _download(MODEL_PATH)
     model = torch.jit.load(model_path, map_location = device).eval()
     n_px = model.input_resolution.item()
 
-    transform = Compose([
-        Resize(n_px, interpolation=Image.BICUBIC),
-        CenterCrop(n_px),
-        lambda image: image.convert("RGB"),
-        ToTensor(),
-        normalize_image,
-    ])
+    normalize_image = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
 
     # patch the device names
     device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
@@ -192,32 +185,7 @@ def patch_device(module):
     model.apply(patch_device)
     patch_device(model.encode_image)
     patch_device(model.encode_text)
-
-    # patch dtype to float32 on CPU
-    if device == "cpu":
-        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
-        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
-        float_node = float_input.node()
-
-        def patch_float(module):
-            graphs = [module.graph] if hasattr(module, "graph") else []
-            if hasattr(module, "forward1"):
-                graphs.append(module.forward1.graph)
-
-            for graph in graphs:
-                for node in graph.findAllNodes("aten::to"):
-                    inputs = list(node.inputs())
-                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
-                        if inputs[i].node()["value"] == 5:
-                            inputs[i].node().copyAttributes(float_node)
-
-        model.apply(patch_float)
-        patch_float(model.encode_image)
-        patch_float(model.encode_text)
-
-        model.float()
-
-    return model, transform
+    return model, normalize_image
 
 tokenizer = SimpleTokenizer()
 
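With this change, load() pins the device to 'cuda' and drops both the PIL preprocessing pipeline and the CPU float32 patching: it now returns the CLIP JIT model together with only the channel-wise Normalize transform, so the caller is responsible for supplying an image tensor at CLIP's input resolution. A minimal usage sketch under those assumptions (not code from this commit; the prompt and the random stand-in tensor are illustrative):

import torch
from deep_daze.clip import load, tokenize

perceptor, normalize_image = load()            # CLIP JIT model + channel-wise Normalize only

text = tokenize("a beautiful sunset").cuda()   # illustrative prompt
image = torch.rand(1, 3, 224, 224).cuda()      # stand-in for a real image tensor at CLIP's 224x224 input size

with torch.no_grad():
    # the downloaded JIT weights are stored in half precision for GPU, so cast the input to match
    image_embed = perceptor.encode_image(normalize_image(image).half())
    text_embed = perceptor.encode_text(text)
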
4 changes: 2 additions & 2 deletions deep_daze/deep_daze.py
@@ -10,7 +10,7 @@
 import torchvision
 from torchvision.utils import save_image
 
-from deep_daze.clip import load, tokenize, normalize_image
+from deep_daze.clip import load, tokenize
 from siren_pytorch import SirenNet, SirenWrapper
 
 import signal
@@ -46,7 +46,7 @@ def rand_cutout(image, size):
 
 # load clip
 
-perceptor, preprocess = load()
+perceptor, normalize_image = load()
 
 # load siren
 
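On the training side, normalize_image is now unpacked from load() rather than imported from deep_daze.clip, and the old preprocess transform disappears. Roughly, the surrounding loop renders an image tensor with the SIREN network, takes random cutouts (the rand_cutout helper named in the hunk header above), resizes them to CLIP's resolution, normalizes them, and scores them against the text embedding. A rough sketch of that step; the function boundary, cutout size, and variable names are assumptions for illustration, not code from the repository:

import torch
import torch.nn.functional as F

def clip_loss(siren_img, text_embed, perceptor, normalize_image, rand_cutout):
    # siren_img: (1, 3, H, W) tensor rendered by the SIREN generator
    cutout = rand_cutout(siren_img, size = 128)                    # illustrative cutout size
    cutout = F.interpolate(cutout, (224, 224), mode = 'bilinear')  # match CLIP's input resolution
    # in training this typically runs under torch.cuda.amp.autocast so dtypes line up with the fp16 CLIP weights
    image_embed = perceptor.encode_image(normalize_image(cutout))
    # maximize cosine similarity between the image and text embeddings
    return -torch.cosine_similarity(text_embed, image_embed, dim = -1).mean()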
