-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGUI.py
441 lines (368 loc) · 14.5 KB
/
GUI.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
import json
import os
##########################
# Load the UI language: Locales/config.json names the language, and
# Locales/<language>.json holds every user-facing string for it.
print("Loading configs...")
config_path = os.path.join('Locales', 'config.json')
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)
lang = config['language']
locale_path = os.path.join('Locales', f'{lang}.json')
with open(locale_path, 'r', encoding='utf-8') as f:
    locale = json.load(f)
print(locale["locale_load_success"])
##########################
# Import libraries. This happens after the locale is loaded so that the
# progress message below can use the localized "import_libs" string.
print(locale["import_libs"])
from llama_cpp import Llama
import gradio as gr
import random
import pyperclip
import re
##########################
# Currently loaded llama_cpp.Llama model; None while no model is loaded.
# Set by load_model(), cleared by unload_model(), checked by upsampling_prompt().
llm = None
##########################
# Theme: Gradio's Ocean preset recoloured to violet/indigo with small corner
# radii, a flat look (the shadow settings below are all 'none') and neutral
# accent/background colours.
_theme_overrides = dict(
    background_fill_primary='*neutral_50',
    border_color_accent='*neutral_50',
    color_accent_soft='*neutral_50',
    shadow_drop='none',
    shadow_drop_lg='none',
    shadow_inset='none',
    shadow_spread='none',
    shadow_spread_dark='none',
    layout_gap='*spacing_xl',
    checkbox_background_color='*primary_50',
    checkbox_background_color_focus='*primary_200',
)
theme = gr.themes.Ocean(
    primary_hue="violet",
    secondary_hue="indigo",
    radius_size="sm",
).set(**_theme_overrides)
##########################
# List the model files available for loading
def list_model_files(models_dir='models'):
    """Return the paths of all ``.gguf`` model files in *models_dir*.

    Args:
        models_dir: Directory to scan (non-recursively). Defaults to the
            ``models`` folder relative to the current working directory.

    Returns:
        A sorted list of paths of the form ``os.path.join(models_dir, name)``.
        Only regular files ending in ``.gguf`` are included. An empty list is
        returned when the directory does not exist, so the UI can still start
        (with an empty model dropdown) instead of crashing at import time.
    """
    if not os.path.isdir(models_dir):
        return []
    # os.listdir order is arbitrary; sort so the dropdown order is stable.
    return [
        os.path.join(models_dir, name)
        for name in sorted(os.listdir(models_dir))
        if name.endswith('.gguf') and os.path.isfile(os.path.join(models_dir, name))
    ]
##########################
# Random seed
def random_seed():
    """Return a fresh random seed drawn uniformly from [1, 2**31 - 1]."""
    upper_exclusive = 2 ** 31
    return random.randrange(1, upper_exclusive)
##########################
# Load a model
def load_model(model_path, gpu, n_ctx):
    """Load a GGUF model with llama.cpp and make it the active model.

    Args:
        model_path: Path of the ``.gguf`` file chosen in the settings tab.
        gpu: Number of layers to offload to the GPU (``n_gpu_layers``).
        n_ctx: Context window size in tokens.

    Returns:
        The localized "model loaded" message, formatted with *model_path*.

    Raises:
        gr.Error: if no model has been selected.
    """
    global llm
    if not model_path:
        raise gr.Error(locale.get("model_not_selected", "Please select a model first."))
    # Drop the reference to the previous model first, presumably so it can be
    # released before the new one is allocated.
    llm = None
    # gr.Number values may arrive as floats; llama.cpp takes these as C ints.
    llm = Llama(model_path=model_path, n_gpu_layers=int(gpu), n_ctx=int(n_ctx))
    return locale["load_model_success"].format(model_path=model_path)
##########################
# Unload the model
def unload_model():
    """Drop the active model and return the localized confirmation message."""
    global llm
    # Clearing the module-level reference presumably lets the model be freed.
    llm = None
    message = locale["unload_model_success"]
    return message
##########################
# Generate (upsample) the prompt

# Prompt label under which the user's input text is placed, keyed by TIPO
# mode. Any mode not listed here feeds the input as a "short" description.
_UPSAMPLE_INPUT_FIELD = {
    "None": "tag",
    "tag_to_long": "tag",
    "tag_to_short_to_long": "tag",
    "long_to_tag": "long",
}


def upsampling_prompt(quality_tags, mode_tags, length_tags, tags, max_token, temp, Seed, top_p, min_p, top_k, rating,
                      artist, characters, meta, length, width):
    """Build a TIPO prompt from the UI fields and let the loaded model extend it.

    Args:
        quality_tags, rating, artist, characters, meta: Values written to the
            matching ``<name>: <value>`` prompt lines.
        mode_tags: TIPO task (e.g. ``"tag_to_long"``), written into the
            ``target`` line; it also decides whether *tags* is placed under
            the ``tag``, ``long`` or ``short`` label.
        length_tags: Target length token written into the ``target`` line.
        tags: The user's input text.
        max_token, temp, Seed, top_p, min_p, top_k: Sampling settings passed
            to the model.
        length, width: Image dimensions; ``round(length / width, 1)`` is
            written as the aspect ratio.

    Returns:
        The prompt followed by the model's continuation (``echo=True``), or the
        localized "model not loaded" message when no model is loaded.
    """
    # Check the model first so that an empty/zero width cannot crash a call
    # that would only report "model not loaded" anyway.
    if llm is None:
        return locale["model_not_loaded"]
    aspect_ratio = round(length / width, 1)
    input_field = _UPSAMPLE_INPUT_FIELD.get(mode_tags, "short")
    prompt = (
        f"quality: {quality_tags}\n"
        f"aspect ratio: {aspect_ratio}\n"
        f"target: <|{length_tags}|> <|{mode_tags}|>\n"
        f"rating: {rating}\n"
        f"artist: {artist}\n"
        f"characters: {characters}\n"
        f"meta: {meta}\n"
        f"{input_field}: {tags}"
    )
    output = llm(
        prompt,
        max_tokens=max_token,
        echo=True,
        temperature=temp,
        seed=Seed,
        top_p=top_p,
        min_p=min_p,
        top_k=top_k,
    )
    return output['choices'][0]['text']
##########################
# Move the artist line to the end of the prompt
def send_artist_to_end(text):
    """Move the ``artist:`` line of a TIPO prompt to the very end of *text*.

    Every line that starts with ``artist:`` (and follows a newline) is
    removed. Trailing newlines are then stripped from the remaining text, so
    no blank line is left before the moved line. Finally the first artist
    line found is appended.

    Args:
        text: A full prompt, e.g. the echoed output of the upsampling step.

    Returns:
        The rearranged prompt. If *text* has no artist line, it is returned
        with its trailing newlines stripped.
    """
    pattern = r"\nartist:.*"
    artist_line = re.search(pattern, text)
    # Strip the body *before* appending; stripping afterwards was a no-op
    # because `.*` never ends with a newline.
    body = re.sub(pattern, "", text).rstrip("\n")
    if artist_line is None:
        return body
    return body + artist_line.group(0)
##########################
# Generate artist tags
def gen_artist_str(prompt, max_token, temp, Seed, top_p, min_p, top_k):
    """Let the loaded model extend the artist line of a generated prompt.

    Args:
        prompt: A full TIPO prompt (the UI passes the raw output of the
            upsampling step). Its artist line is moved to the end first.
        max_token, temp, Seed, top_p, min_p, top_k: Sampling settings passed
            to the model.

    Returns:
        The rearranged prompt plus the model's continuation (``echo=True``),
        cut at the stop string ``"target"``; or the localized "model not
        loaded" message when no model is loaded.
    """
    # Same guard as upsampling_prompt: calling llm while it is None would
    # raise TypeError ('NoneType' object is not callable).
    if llm is None:
        return locale["model_not_loaded"]
    prompt = send_artist_to_end(prompt)
    output = llm(
        prompt,
        max_tokens=max_token,
        echo=True,
        temperature=temp,
        seed=Seed,
        top_p=top_p,
        min_p=min_p,
        top_k=top_k,
        # Presumably stops generation when the model starts a new prompt
        # section ("target: ...") after finishing the artist line.
        stop=["target"],
    )
    return output['choices'][0]['text']
##########################
# Format the model output

# Prompt fields to keep for each TIPO mode, in display order.
_FORMAT_FIELDS_BY_MODE = {
    "None": ('quality', 'artist', 'characters', 'meta', 'rating', 'tag'),
    "tag_to_long": ('quality', 'artist', 'characters', 'meta', 'rating', 'tag', 'long'),
    "tag_to_short_to_long": ('quality', 'artist', 'characters', 'meta', 'rating', 'tag', 'short', 'long'),
    "long_to_tag": ('quality', 'artist', 'characters', 'meta', 'rating', 'long', 'tag'),
    "short_to_long": ('quality', 'artist', 'characters', 'meta', 'rating', 'short', 'long'),
    "short_to_tag_to_long": ('quality', 'artist', 'characters', 'meta', 'rating', 'short', 'tag', 'long'),
    "short_to_long_to_tag": ('quality', 'artist', 'characters', 'meta', 'rating', 'short', 'long', 'tag'),
    "short_to_tag": ('quality', 'artist', 'characters', 'meta', 'rating', 'short', 'tag'),
}


def extract_and_format(model_out, mode_tags):
    """Pull the relevant ``field: value`` lines out of a model output.

    Args:
        model_out: Raw prompt/model text with one ``field: value`` per line.
        mode_tags: TIPO mode; selects which fields are kept and their order.

    Returns:
        The non-empty values of the selected fields, in mode order, separated
        by blank lines. If a field appears several times, the last occurrence
        wins. For an unknown mode, an error message is printed and returned.
    """
    fields = _FORMAT_FIELDS_BY_MODE.get(mode_tags)
    if fields is None:
        print("Error: Invalid mode_tags value")
        return "Error: Invalid mode_tags value"
    values = {}
    for line in model_out.split('\n'):
        label, colon, rest = line.partition(':')
        if colon and label in fields:
            values[label] = rest.strip()
    return "\n\n".join(values[name] for name in fields if values.get(name))
##########################
# Remove banned tags
def remove_words_by_regex(sentence, pattern):
    """Remove the comma-separated items of *sentence* that match a banned pattern.

    Args:
        sentence: Text whose items are separated by commas (optionally followed
            by whitespace). Lines are processed independently. The first item
            of every line is therefore tested on its own, and line breaks are
            preserved.
        pattern: Comma-separated list of regular expressions. Each one is
            applied with ``re.match``, i.e. anchored at the start of an item
            only: ``"girl"`` also removes ``"girlfriend"``, while ``"girl$"``
            matches exactly. Empty entries are ignored.

    Returns:
        *sentence* with matching items removed. Remaining items on each line
        are re-joined with ``", "``. If *pattern* has no non-empty entries,
        *sentence* is returned unchanged.

    Raises:
        re.error: if one of the patterns is not a valid regular expression.
    """
    # Empty entries (from "a,,b" or a lone ", ") would match every item, so skip them.
    regexes = [re.compile(p) for p in re.split(r',\s*', pattern.strip(', ')) if p]
    if not regexes:
        return sentence
    filtered_lines = []
    for line in sentence.split('\n'):
        kept = [word for word in re.split(r',\s*', line)
                if not any(rx.match(word) for rx in regexes)]
        filtered_lines.append(', '.join(kept))
    return '\n'.join(filtered_lines)
##########################
# Refresh the formatted output
def update_format_output(formatted_text, banned_tags, mode_tags):
    """Rebuild the formatted-result textbox from raw model output.

    Extracts and orders the fields for *mode_tags* and drops any items that
    match the comma-separated *banned_tags* patterns (when any are given).
    The result is returned as a read-only ``gr.Textbox``.
    """
    text = extract_and_format(formatted_text, mode_tags)
    cleaned = remove_words_by_regex(text, banned_tags) if banned_tags else text
    return gr.Textbox(value=cleaned, interactive=False)
##########################
# Copy to clipboard
def copy_to_clipboard(output):
    """Copy *output* to the clipboard and show a success toast.

    NOTE(review): pyperclip writes to the clipboard of the machine running
    this script, not the browser's, so this only helps when the UI is used
    locally.

    Raises:
        gr.Error: with the localized failure message if copying fails.
    """
    try:
        pyperclip.copy(output)
    except Exception as e:
        # Surface any clipboard-backend failure as a UI error while keeping
        # the original cause chained for the server log.
        raise gr.Error(locale["copy_fail"]) from e
    # Outside the try so a UI notification problem is not reported as a copy failure.
    gr.Info(locale["copy_success"])
##########################
# Discover the model files the user can pick from
print(locale["model_searching"])
available_models = list_model_files()
##########################
print(locale["gradio_launching"])
##########################
# Read the tutorial written for the configured language
tutorial_path = os.path.join('Locales', 'Tutorials', f'{lang}.md')
with open(tutorial_path, "r", encoding="utf-8") as tutorial:
    tutorial_content = tutorial.read()
##########################
# Gradio UI.
# Integer-valued settings use precision=0 so Gradio hands them over as ints;
# seed, top_k, max_tokens, n_ctx and n_gpu_layers end up as C ints in llama.cpp.
with gr.Blocks(theme=theme, title="TIPO") as demo:
    with gr.Row():
        with gr.Column():
            # -------------------------
            # Generate tab
            with gr.Tab(locale["tab_generate"]):
                with gr.Row(equal_height=True):
                    # Seed (-1 by default; passed to the model as-is)
                    Seed = gr.Number(label=locale["seed"], value=-1, precision=0)
                    Seed_random = gr.Button(locale["random_seed"])
                with gr.Row():
                    # Image size, used to compute the aspect ratio
                    img_length = gr.Number(label=locale["img_length"], value=512, minimum=256, maximum=2048, step=1,
                                           precision=0)
                    img_width = gr.Number(label=locale["img_width"], value=512, minimum=256, maximum=2048, step=1,
                                          precision=0)
                with gr.Row():
                    # Mode and length tags
                    mode_tags = gr.Dropdown(
                        label=locale["mode"],
                        choices=["None", "tag_to_long", "long_to_tag", "short_to_long", "short_to_tag",
                                 "tag_to_short_to_long", "short_to_tag_to_long", "short_to_long_to_tag"],
                        value="None"
                    )
                    length_tags = gr.Dropdown(
                        label=locale["length"],
                        choices=["very_short", "short", "long", "very_long"],
                        value="short"
                    )
                with gr.Row():
                    # Quality tags and banned tags
                    quality_tags = gr.Textbox(label=locale["quality"], value="masterpiece")
                    banned_tags = gr.Textbox(label=locale["banned_tags"])
                with gr.Row():
                    # Rating and artist
                    rating_tags = gr.Dropdown(label=locale["rating"], choices=["safe", "sensitive", "nsfw", "explicit"],
                                              value="safe")
                    artist_tags = gr.Textbox(label=locale["artist"])
                with gr.Row():
                    # Characters and meta tags
                    character_tags = gr.Textbox(label=locale["character"])
                    meta_tags = gr.Textbox(label=locale["meta"], value="hires")
                # General tags
                tags = gr.Textbox(label=locale["general_tags"])
            # -------------------------
            # Artist tab
            with gr.Tab(locale["tab_artist"]):
                with gr.Row(equal_height=True):
                    Seed2 = gr.Number(label=locale["seed"], value=-1, precision=0)
                    Seed_random2 = gr.Button(locale["random_seed"])
                # Artist-generation result
                artist_tags_textbox = gr.Textbox(label=locale["artist"])
            # -------------------------
            # Settings tab
            with gr.Tab(locale["tab_settings"]):
                # Model settings
                gr.Markdown(locale["model_settings"])
                model_selector = gr.Dropdown(label=locale["model_select"], choices=available_models)
                with gr.Row():
                    n_ctx = gr.Number(label="n_ctx", value=2048, precision=0)
                    n_gpu_layers = gr.Number(label="n_gpu_layers", value=-1, precision=0)
                with gr.Row():
                    unload_btn = gr.Button(locale["model_unload"])
                    load_btn = gr.Button(locale["model_load"], variant="primary")
                load_feedback = gr.Markdown("")
                # Sampling settings
                gr.Markdown(locale["generate_settings"])
                with gr.Row():
                    top_p = gr.Number(label="top_p", value=0.95)
                    min_p = gr.Number(label="min_p", value=0.05)
                with gr.Row():
                    max_tokens = gr.Number(label="max_tokens", value=1024, precision=0)
                    temprature = gr.Number(label="temperature", value=0.8)
                    top_k = gr.Number(label="top_k", value=60, precision=0)
            # -------------------------
            # Tutorial tab
            with gr.Tab(locale["tab_tutorial"]):
                gr.Markdown(tutorial_content)
        # -------------------------
        # Results
        with gr.Column():
            with gr.Row():
                upsampling_btn = gr.Button("TIPO!", variant="primary", scale=2)
                copy_btn = gr.Button(locale["copy_to_clipboard"], scale=1)
                gen_artists = gr.Button(locale["generate_artists"], scale=1)
            with gr.Row():
                raw_output = gr.Textbox(label=locale["result"], interactive=False)
                formatted_output = gr.Textbox(label=locale["formatted_result"], interactive=False)
    # -------------------------
    # Refresh the formatted output whenever either raw output changes
    raw_output.change(update_format_output, inputs=[raw_output, banned_tags, mode_tags],
                      outputs=formatted_output)
    artist_tags_textbox.change(update_format_output, inputs=[artist_tags_textbox, banned_tags, mode_tags],
                               outputs=formatted_output)
    # -------------------------
    # Write the prompt
    upsampling_btn.click(
        fn=upsampling_prompt,
        inputs=[quality_tags, mode_tags, length_tags, tags, max_tokens, temprature, Seed, top_p, min_p, top_k,
                rating_tags, artist_tags, character_tags, meta_tags, img_length, img_width],
        outputs=raw_output
    )
    # -------------------------
    # Load the model
    load_btn.click(
        fn=load_model,
        inputs=[model_selector, n_gpu_layers, n_ctx],
        outputs=load_feedback
    )
    # -------------------------
    # Unload the model
    unload_btn.click(
        fn=unload_model,
        inputs=None,
        outputs=load_feedback
    )
    # -------------------------
    # Random seeds
    Seed_random.click(
        fn=random_seed,
        inputs=None,
        outputs=Seed
    )
    Seed_random2.click(
        fn=random_seed,
        inputs=None,
        outputs=Seed2
    )
    # -------------------------
    # Copy to clipboard
    copy_btn.click(
        fn=copy_to_clipboard,
        inputs=formatted_output,
        outputs=None
    )
    # -------------------------
    # Generate the artist string
    gen_artists.click(
        fn=gen_artist_str,
        inputs=[raw_output, max_tokens, temprature, Seed2, top_p, min_p, top_k],
        outputs=artist_tags_textbox
    )
demo.launch()