forked from XLabs-AI/x-flux
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
108 lines (98 loc) · 3.58 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import argparse
from PIL import Image
import os
from src.flux.xflux_pipeline import XFluxPipeline
def create_argparser():
    """Build the command-line parser for single-image Flux generation.

    Options fall into four groups: the text prompt; optional Controlnet
    weights (``--use_controlnet`` with ``--local_path`` or
    ``--repo_id``/``--name``, plus ``--image`` as the control image); optional
    LoRA weights (``--use_lora`` with ``--lora_repo_id``/``--lora_name`` and
    ``--lora_weight``); and base-model / sampling / output settings.

    Returns:
        argparse.ArgumentParser: parser whose only required option is ``--prompt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt", type=str, required=True,
        help="The input text prompt"
    )
    # --- Controlnet checkpoint source (local file or HuggingFace download) ---
    parser.add_argument(
        "--local_path", type=str, default=None,
        help="Local path to the model checkpoint (Controlnet)"
    )
    parser.add_argument(
        "--repo_id", type=str, default=None,
        help="A HuggingFace repo id to download model (Controlnet)"
    )
    parser.add_argument(
        "--name", type=str, default=None,
        help="A filename to download from HuggingFace (Controlnet, used with --repo_id)"
    )
    # --- LoRA checkpoint source ---
    parser.add_argument(
        "--lora_repo_id", type=str, default=None,
        help="A HuggingFace repo id to download LoRA model"
    )
    parser.add_argument(
        "--lora_name", type=str, default=None,
        help="A LoRA filename to download from HuggingFace (used with --lora_repo_id)"
    )
    # --- Runtime / feature switches ---
    parser.add_argument(
        "--device", type=str, default="cuda",
        help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)"
    )
    parser.add_argument(
        "--offload", action='store_true', help="Offload model to CPU when not in use"
    )
    parser.add_argument(
        "--use_lora", action='store_true', help="Load Lora model"
    )
    parser.add_argument(
        "--use_controlnet", action='store_true', help="Load Controlnet model"
    )
    parser.add_argument(
        "--image", type=str, default=None,
        help="Path to the control image fed to the Controlnet (used with --use_controlnet)"
    )
    parser.add_argument(
        "--lora_weight", type=float, default=0.9, help="Lora model strength (from 0 to 1.0)"
    )
    # --- Base model and sampling settings ---
    parser.add_argument(
        "--model_type", type=str, default="flux-dev",
        choices=("flux-dev", "flux-dev-fp8", "flux-schnell"),
        help="Model type to use (flux-dev, flux-dev-fp8, flux-schnell)"
    )
    parser.add_argument(
        "--width", type=int, default=512, help="The width for generated image"
    )
    parser.add_argument(
        "--height", type=int, default=512, help="The height for generated image"
    )
    parser.add_argument(
        "--num_steps", type=int, default=50, help="The num_steps for diffusion process"
    )
    parser.add_argument(
        "--guidance", type=float, default=3.5, help="The guidance for diffusion process"
    )
    parser.add_argument(
        "--seed", type=int, default=123456789, help="A seed for reproducible inference"
    )
    # --- Output ---
    parser.add_argument(
        "--save_path", type=str, default='results', help="Path to save"
    )
    return parser
def _next_result_path(save_path):
    """Return the first ``result_<n>.png`` path inside *save_path* that does not exist.

    The search starts at ``n = len(os.listdir(save_path))`` (the script's
    historical naming scheme) and counts upward. This prevents an existing
    result from being overwritten when earlier files were deleted or the
    directory also holds unrelated files.
    """
    ind = len(os.listdir(save_path))
    path = os.path.join(save_path, f"result_{ind}.png")
    while os.path.exists(path):
        ind += 1
        path = os.path.join(save_path, f"result_{ind}.png")
    return path


def main(args):
    """Run one generation with XFluxPipeline from parsed CLI args and save it.

    Steps:
      1. Optionally load ``args.image`` (passed to the pipeline as
         ``controlnet_image``).
      2. Build the pipeline from model type, device, offload flag and seed.
      3. Optionally attach LoRA weights (``--use_lora``) and a "canny"
         Controlnet (``--use_controlnet``).
      4. Generate from the prompt and save the result into ``args.save_path``
         as ``result_<n>.png``, never overwriting an existing file.

    Args:
        args: namespace produced by ``create_argparser().parse_args()``.
    """
    image = None
    if args.image:
        # Image.open is lazy and would keep the file handle open for the whole
        # (long) generation; copy() pulls the pixels into memory so the
        # handle is released when the `with` block exits.
        with Image.open(args.image) as src:
            image = src.copy()

    xflux_pipeline = XFluxPipeline(args.model_type, args.device, args.offload, args.seed)
    if args.use_lora:
        print('load lora:', args.lora_repo_id, args.lora_name)
        # First positional argument is passed as None; presumably a local
        # LoRA path (only the HuggingFace source is exposed here) — confirm
        # against XFluxPipeline.set_lora.
        xflux_pipeline.set_lora(None, args.lora_repo_id, args.lora_name, args.lora_weight)
    if args.use_controlnet:
        # The control type is fixed to "canny" by this script.
        xflux_pipeline.set_controlnet("canny", args.local_path, args.repo_id, args.name)

    result = xflux_pipeline(prompt=args.prompt,
                            controlnet_image=image,
                            width=args.width,
                            height=args.height,
                            guidance=args.guidance,
                            num_steps=args.num_steps,
                            )

    # makedirs handles nested paths (e.g. "out/run1") and an already-existing
    # directory without the exists()/mkdir() race.
    os.makedirs(args.save_path, exist_ok=True)
    # NOTE(review): `result` is assumed to expose a PIL-style .save(path).
    result.save(_next_result_path(args.save_path))
if __name__ == "__main__":
args = create_argparser().parse_args()
main(args)