Skip to content

Commit

Permalink
Merge pull request #57 from ZeroCool940711/dev
Browse files Browse the repository at this point in the history
Added support for both 3- and 4-channel images.
  • Loading branch information
ZeroCool940711 authored Jul 12, 2023
2 parents 7713473 + 113af92 commit 80c1a63
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 10 deletions.
5 changes: 4 additions & 1 deletion infer_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@
parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate.")
parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.")
parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE.")
parser.add_argument(
"--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
)
parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
parser.add_argument(
Expand Down Expand Up @@ -435,6 +437,7 @@ def main():
center_crop=True if not args.no_center_crop and not args.random_crop else False,
flip=not args.no_flip,
random_crop=args.random_crop if args.random_crop else False,
alpha_channel=False if args.channels == 3 else True,
)

if args.input_image and not args.input_folder:
Expand Down
28 changes: 25 additions & 3 deletions muse_maskgit_pytorch/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,23 @@ def __init__(
stream=False,
using_taming=False,
random_crop=False,
alpha_channel=True,
):
super().__init__()
self.dataset = dataset
self.image_column = image_column
self.stream = stream
transform_list = [
T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
T.Lambda(
lambda img: img.convert("RGBA")
if img.mode != "RGBA" and alpha_channel
else img
if img.mode == "RGB" and not alpha_channel
else img.convert("RGB")
),
T.Resize(image_size),
]

if flip:
transform_list.append(T.RandomHorizontalFlip())
if center_crop and not random_crop:
Expand Down Expand Up @@ -199,7 +207,15 @@ def __getitem__(self, index):

class LocalTextImageDataset(Dataset):
def __init__(
self, path, image_size, tokenizer, flip=True, center_crop=True, using_taming=False, random_crop=False
self,
path,
image_size,
tokenizer,
flip=True,
center_crop=True,
using_taming=False,
random_crop=False,
alpha_channel=False,
):
super().__init__()
self.tokenizer = tokenizer
Expand Down Expand Up @@ -229,7 +245,13 @@ def __init__(
self.caption_pair.append(captions)

transform_list = [
T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
T.Lambda(
lambda img: img.convert("RGBA")
if img.mode != "RGBA" and alpha_channel
else img
if img.mode == "RGB" and not alpha_channel
else img.convert("RGB")
),
T.Resize(image_size),
]
if flip:
Expand Down
9 changes: 5 additions & 4 deletions muse_maskgit_pytorch/vqgan_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def forward(self, x):

# discriminator
class Discriminator(nn.Module):
def __init__(self, dims, channels=3, groups=16, init_kernel_size=5):
def __init__(self, dims, channels=4, groups=16, init_kernel_size=5):
super().__init__()
dim_pairs = zip(dims[:-1], dims[1:])

Expand Down Expand Up @@ -194,7 +194,7 @@ def __init__(
self,
dim: int,
*,
channels=3,
channels=4,
layers=4,
layer_mults=None,
num_resnet_blocks=1,
Expand Down Expand Up @@ -337,7 +337,7 @@ def __init__(
*,
dim: int,
accelerator: Accelerator = None,
channels=3,
channels=4,
layers=4,
l2_recon_loss=False,
use_hinge_loss=True,
Expand Down Expand Up @@ -407,7 +407,8 @@ def vgg(self):
if exists(self._vgg):
return self._vgg

vgg = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
vgg = torchvision.models.vgg16(pretrained=True)
vgg.features[0] = nn.Conv2d(self.channels, 64, kernel_size=3, stride=1, padding=1)
vgg.classifier = nn.Sequential(*vgg.classifier[:-2])
self._vgg = vgg.to(self.device)
return self._vgg
Expand Down
5 changes: 4 additions & 1 deletion train_muse_maskgit.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,9 @@
help="Image Size.",
)
parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE.")
parser.add_argument(
"--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
)
parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
parser.add_argument(
Expand Down Expand Up @@ -812,6 +814,7 @@ def main():
flip=False if args.no_flip else True,
using_taming=False if not args.taming_model_path else True,
random_crop=args.random_crop if args.random_crop else False,
alpha_channel=False if args.channels == 3 else True,
)
elif args.link:
if not args.dataset_name:
Expand Down
5 changes: 4 additions & 1 deletion train_muse_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,9 @@
)
parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.")
parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE.")
parser.add_argument(
"--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
)
parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
parser.add_argument(
Expand Down Expand Up @@ -563,6 +565,7 @@ def main():
flip=not args.no_flip,
stream=args.streaming,
random_crop=args.random_crop,
alpha_channel=False if args.channels == 3 else True,
)
# dataloader

Expand Down

0 comments on commit 80c1a63

Please sign in to comment.