
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jul 12, 2023
1 parent 28d9b5c commit 113af92
Showing 5 changed files with 33 additions and 8 deletions.
4 changes: 3 additions & 1 deletion infer_vae.py
@@ -112,7 +112,9 @@
 parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate.")
 parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.")
 parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
-parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA.")
+parser.add_argument(
+    "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
+)
 parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
 parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
 parser.add_argument(
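
Aside: the reflowed --channels flag above only changes formatting, not behaviour. As an illustration of what the parsed value controls, here is a minimal standalone argparse sketch; only the --channels definition mirrors the diff, the rest is an assumption for demonstration and not code from this repository.

import argparse

# Minimal sketch: parse --channels the same way the diff above defines it.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
)
args = parser.parse_args(["--channels", "4"])  # request 4-channel (RGBA-style) input
assert args.channels in (3, 4), "expected 3 (RGB) or 4 (RGBA)"
print(args.channels)  # -> 4
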
28 changes: 24 additions & 4 deletions muse_maskgit_pytorch/dataset.py
@@ -45,14 +45,20 @@ def __init__(
         stream=False,
         using_taming=False,
         random_crop=False,
-        alpha_channel=True
+        alpha_channel=True,
     ):
         super().__init__()
         self.dataset = dataset
         self.image_column = image_column
         self.stream = stream
         transform_list = [
-            T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" and alpha_channel else img if img.mode == "RGB" and not alpha_channel else img.convert("RGB")),
+            T.Lambda(
+                lambda img: img.convert("RGBA")
+                if img.mode != "RGBA" and alpha_channel
+                else img
+                if img.mode == "RGB" and not alpha_channel
+                else img.convert("RGB")
+            ),
             T.Resize(image_size),
         ]
 
@@ -201,7 +207,15 @@ def __getitem__(self, index):
 
 class LocalTextImageDataset(Dataset):
     def __init__(
-        self, path, image_size, tokenizer, flip=True, center_crop=True, using_taming=False, random_crop=False, alpha_channel=False
+        self,
+        path,
+        image_size,
+        tokenizer,
+        flip=True,
+        center_crop=True,
+        using_taming=False,
+        random_crop=False,
+        alpha_channel=False,
     ):
         super().__init__()
         self.tokenizer = tokenizer
@@ -231,7 +245,13 @@ def __init__(
             self.caption_pair.append(captions)
 
         transform_list = [
-            T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" and alpha_channel else img if img.mode == "RGB" and not alpha_channel else img.convert("RGB")),
+            T.Lambda(
+                lambda img: img.convert("RGBA")
+                if img.mode != "RGBA" and alpha_channel
+                else img
+                if img.mode == "RGB" and not alpha_channel
+                else img.convert("RGB")
+            ),
             T.Resize(image_size),
         ]
         if flip:
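
The chained conditional inside the T.Lambda above is easy to misread once it is split across lines. A minimal spelled-out equivalent, for clarity only (convert_mode is a hypothetical helper name, not part of this repository):

from PIL import Image


def convert_mode(img: Image.Image, alpha_channel: bool) -> Image.Image:
    # Same logic as the lambda: `a if c1 else b if c2 else c`
    # parses as `a if c1 else (b if c2 else c)`.
    if img.mode != "RGBA" and alpha_channel:
        return img.convert("RGBA")  # e.g. L, P, RGB inputs become RGBA
    if img.mode == "RGB" and not alpha_channel:
        return img  # already plain RGB, nothing to do
    return img.convert("RGB")  # every other case, including RGBA inputs

Read this way, an image that is already in RGBA mode falls through to the last branch and is converted to RGB even when alpha_channel is True; the pre-commit reformat only reflows the expression and does not change that behaviour.
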
1 change: 0 additions & 1 deletion muse_maskgit_pytorch/vqgan_vae.py
@@ -413,7 +413,6 @@ def vgg(self):
         self._vgg = vgg.to(self.device)
         return self._vgg
 
-
     @property
     def encoded_dim(self):
         return self.enc_dec.encoded_dim
4 changes: 3 additions & 1 deletion train_muse_maskgit.py
@@ -298,7 +298,9 @@
     help="Image Size.",
 )
 parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
-parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA.")
+parser.add_argument(
+    "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
+)
 parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
 parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
 parser.add_argument(
4 changes: 3 additions & 1 deletion train_muse_vae.py
@@ -222,7 +222,9 @@
 )
 parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.")
 parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
-parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA.")
+parser.add_argument(
+    "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
+)
 parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
 parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
 parser.add_argument(
