diff --git a/infer_vae.py b/infer_vae.py
index 01022a1..f5d6efd 100644
--- a/infer_vae.py
+++ b/infer_vae.py
@@ -112,7 +112,9 @@
 parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate.")
 parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.")
 parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
-parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA.")
+parser.add_argument(
+    "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
+)
 parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
 parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
 parser.add_argument(
diff --git a/muse_maskgit_pytorch/dataset.py b/muse_maskgit_pytorch/dataset.py
index f8cba1c..7659e38 100644
--- a/muse_maskgit_pytorch/dataset.py
+++ b/muse_maskgit_pytorch/dataset.py
@@ -45,14 +45,20 @@ def __init__(
         stream=False,
         using_taming=False,
         random_crop=False,
-        alpha_channel=True
+        alpha_channel=True,
     ):
         super().__init__()
         self.dataset = dataset
         self.image_column = image_column
         self.stream = stream
 
         transform_list = [
-            T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" and alpha_channel else img if img.mode == "RGB" and not alpha_channel else img.convert("RGB")),
+            T.Lambda(
+                lambda img: img.convert("RGBA")
+                if img.mode != "RGBA" and alpha_channel
+                else img
+                if img.mode == "RGB" and not alpha_channel
+                else img.convert("RGB")
+            ),
             T.Resize(image_size),
         ]
@@ -201,7 +207,15 @@ def __getitem__(self, index):
 
 class LocalTextImageDataset(Dataset):
     def __init__(
-        self, path, image_size, tokenizer, flip=True, center_crop=True, using_taming=False, random_crop=False, alpha_channel=False
+        self,
+        path,
+        image_size,
+        tokenizer,
+        flip=True,
+        center_crop=True,
+        using_taming=False,
+        random_crop=False,
+        alpha_channel=False,
     ):
         super().__init__()
         self.tokenizer = tokenizer
@@ -231,7 +245,13 @@ def __init__(
             self.caption_pair.append(captions)
 
         transform_list = [
-            T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" and alpha_channel else img if img.mode == "RGB" and not alpha_channel else img.convert("RGB")),
+            T.Lambda(
+                lambda img: img.convert("RGBA")
+                if img.mode != "RGBA" and alpha_channel
+                else img
+                if img.mode == "RGB" and not alpha_channel
+                else img.convert("RGB")
+            ),
             T.Resize(image_size),
         ]
         if flip:
diff --git a/muse_maskgit_pytorch/vqgan_vae.py b/muse_maskgit_pytorch/vqgan_vae.py
index 0887fe6..5715e65 100644
--- a/muse_maskgit_pytorch/vqgan_vae.py
+++ b/muse_maskgit_pytorch/vqgan_vae.py
@@ -413,7 +413,6 @@ def vgg(self):
         self._vgg = vgg.to(self.device)
         return self._vgg
 
-
     @property
     def encoded_dim(self):
         return self.enc_dec.encoded_dim
diff --git a/train_muse_maskgit.py b/train_muse_maskgit.py
index e711cff..4ecd529 100644
--- a/train_muse_maskgit.py
+++ b/train_muse_maskgit.py
@@ -298,7 +298,9 @@
     help="Image Size.",
 )
 parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
-parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA.")
+parser.add_argument(
+    "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
+)
 parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
 parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
 parser.add_argument(
diff --git a/train_muse_vae.py b/train_muse_vae.py
index 639edaf..655eb2a 100644
--- a/train_muse_vae.py
+++ b/train_muse_vae.py
@@ -222,7 +222,9 @@
 )
 parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.")
 parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
-parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA.")
+parser.add_argument(
+    "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
+)
 parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
 parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
 parser.add_argument(
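For review convenience, here is a minimal sketch (not part of the diff) of the mode-conversion behavior that the rewrapped T.Lambda in muse_maskgit_pytorch/dataset.py preserves; the chained conditional is only re-indented by this change, not altered. The helper name convert_mode is hypothetical, and Pillow is assumed to be installed.

    from PIL import Image

    def convert_mode(img: Image.Image, alpha_channel: bool) -> Image.Image:
        # Mirrors the chained conditional in the T.Lambda above
        # (conditional expressions associate to the right).
        if img.mode != "RGBA" and alpha_channel:
            return img.convert("RGBA")  # add an alpha channel when one is requested
        if img.mode == "RGB" and not alpha_channel:
            return img  # already RGB and no alpha wanted: pass through unchanged
        return img.convert("RGB")  # every other case falls back to RGB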