-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_SD_cifar.py
105 lines (99 loc) · 5.63 KB
/
gen_SD_cifar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import torchvision.transforms as transforms
import torch
import re
import random
from tqdm import tqdm
import argparse
import os
from huggingface_hub import snapshot_download
from imbalance_cifar import IMBALANCECIFAR10
# NOTE(review): removed a ~50-line triple-quoted block of commented-out
# argparse code (`get_arguments()` for a separate training script). It was a
# no-op string expression and dead weight here; recover from version control
# if it is ever needed again.
# Pick the compute device for the diffusion pipeline below.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): called unconditionally; effectively a no-op when CUDA is
# unavailable — confirm on the CPU-only path for the installed torch version.
torch.cuda.empty_cache()
# Long-tailed CIFAR-10 train split (project-local IMBALANCECIFAR10);
# downloaded to ./data on first run. Only its per-class counts are used later.
trainset = IMBALANCECIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
print('trainset loaded')
print(len(trainset))
# cifar_0.01.txt holds one "<class id>\t<BLIP-v2 caption>" line per
# long-tailed training image; collect class ids and captions in parallel lists.
labels = []
prompts = []
with open('cifar_0.01.txt', 'r') as f:
    raw_lines = f.read().split('\n')
for entry in raw_lines:
    if not entry:
        continue  # skip blank lines (e.g. trailing newline)
    cls_num, caption = entry.split('\t', 1)
    labels.append(cls_num)
    prompts.append(caption)
# Per-class sample counts of the imbalanced train set; generate exactly enough
# synthetic images to top every class up to the size of the largest class.
num_per_cls = trainset.num_list
print('num per cls: ', num_per_cls)
largest = max(num_per_cls)
nums_to_gen = [largest - count for count in num_per_cls]
total_nums_to_gen = sum(nums_to_gen)
print('nums to gen: ', nums_to_gen)
print('total nums to gen: ', total_nums_to_gen)
# One-time weight download (uncomment on first run):
#snapshot_download(repo_id='stabilityai/stable-diffusion-2-1-base',local_dir='/root/autodl-tmp/stablediffusion2-1_pretrained',cache_dir='/root/autodl-tmp/stablediffusion2-1_pretrained',force_download=True,resume_download=False)
# Load Stable Diffusion 2.1-base from the local snapshot in fp16, with the
# safety checker disabled, and move it to the selected device.
pipe = StableDiffusionPipeline.from_pretrained(
    '/root/autodl-tmp/stablediffusion2-1_pretrained',
    torch_dtype=torch.float16,
    requires_safety_checker=False,
    safety_checker=None,
    local_files_only=True,
).to(device)
# Generate the make-up images class by class, prompting SD with a randomly
# chosen BLIP-v2 caption belonging to the same class, and save 32x32 JPEGs
# under <out_root>/<class index>/<j>.jpg.
img_size = 512
out_root = '/root/autodl-tmp/cifar_0.01_2.0'
# Fixed: counter mislabeled as SI-scaled bytes (unit='B', unit_scale=True);
# each pbar step is one generated image.
with tqdm(total=total_nums_to_gen, desc='generating augmentation images for LT data with SD:',
          leave=True, ncols=100, unit='img') as pbar:
    for cls_idx, n_gen in enumerate(nums_to_gen):
        if n_gen == 0:
            continue
        # Captions of class cls_idx occupy a contiguous slice of `prompts`
        # (assumes the caption file is ordered by class — TODO confirm);
        # hoisted out of the per-image loop (loop-invariant).
        cls_prompt_range = range(sum(num_per_cls[:cls_idx]), sum(num_per_cls[:cls_idx + 1]))
        out_dir = os.path.join(out_root, str(cls_idx))
        # Race-free, idempotent; replaces the per-image exists()+makedirs().
        os.makedirs(out_dir, exist_ok=True)
        for j in range(n_gen):
            prompt_choice = random.choice(cls_prompt_range)
            # NOTE(review): the generator seed depends only on j
            # (total_nums_to_gen == sum(nums_to_gen) is constant), so every
            # class reuses the same seed sequence; kept as-is to preserve the
            # original reproducibility behavior.
            image = pipe(prompt=prompts[prompt_choice],
                         height=img_size, width=img_size,
                         num_inference_steps=50,
                         num_images_per_prompt=1,
                         guidance_scale=2,
                         generator=torch.Generator().manual_seed(total_nums_to_gen + j)).images[0]
            # Downscale the 512x512 sample to CIFAR resolution before saving.
            image = image.resize((32, 32))
            image.save(os.path.join(out_dir, str(j) + '.jpg'))
            pbar.update(1)