forked from victorchall/EveryDream2trainer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
caption_kosmos2.py
135 lines (109 loc) · 5.79 KB
/
caption_kosmos2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
Copyright [2022-2023] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at
https://www.gnu.org/licenses/agpl-3.0.en.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import io
import argparse
import time
import torch
from PIL import Image
from pynvml import *
from transformers import AutoProcessor, AutoModelForVision2Seq
import colorama
# Prefix that main() prepends to the user prompt before handing it to the processor.
GROUNDING = "<grounding>"
# File extensions main() treats as images; compared against the lower-cased extension.
SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
def get_gpu_memory_map():
    """Return the used memory of GPU index 0 as reported by NVML, scaled by 1/1024/1024.

    NVML is initialised and shut down on every call. The shutdown runs in a
    ``finally`` clause so a failing handle/memory query does not leave the
    library initialised.

    Returns:
        ``info.used / 1024 / 1024`` for device 0 (the caller labels this "mb").

    Raises:
        Whatever the NVML bindings raise when initialisation or the device
        query fails (e.g. no NVIDIA driver present).
    """
    nvmlInit()
    try:
        handle = nvmlDeviceGetHandleByIndex(0)
        info = nvmlDeviceGetMemoryInfo(handle)
        return info.used / 1024 / 1024
    finally:
        # Always pair nvmlInit() with nvmlShutdown(), even on error.
        nvmlShutdown()
def remove_starting_string(a, b):
    """Return ``b`` with the leading prefix ``a`` removed, if present.

    The match is tried twice: first exactly as given, then with surrounding
    whitespace stripped from both strings. In the second case the result is
    taken from the stripped ``b``. If neither attempt matches, ``b`` is
    returned unchanged.
    """
    # Exact prefix match: cut it off and keep the rest verbatim.
    if b.startswith(a):
        return b[len(a):]

    # Whitespace-insensitive match on both ends of both strings.
    prefix = a.strip()
    candidate = b.strip()
    if candidate.startswith(prefix):
        return candidate[len(prefix):]

    return b
def main(args):
    """Caption every supported image under ``args.data_root`` with Kosmos-2.

    For each image the prompt (prefixed with ``GROUNDING``) is fed to the
    model, the generated text is post-processed, and the caption is written
    to ``<image stem>.txt`` next to the image. Entities are optionally
    written to ``<image stem>.ent``.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``cpu``, ``dtype``, ``data_root``, ``prompt``, ``max_new_tokens``,
            ``keep_prompt``, ``overwrite`` and ``save_entities``.
    """
    model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224")
    processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")

    dtype = torch.float32
    if not args.cpu:
        dtype_by_name = {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "fp32": torch.float32,
        }
        if args.dtype not in dtype_by_name:
            # Unknown names silently meant fp32 before; keep that, but say so.
            print(f"Unknown dtype '{args.dtype}', falling back to fp32")
        dtype = dtype_by_name.get(args.dtype, torch.float32)
        model = model.to(dtype=dtype).cuda()
        print(f"Using cuda, model dtype: {model.dtype}")
    else:
        print(f"Using cpu, model dtype: {model.dtype}")

    device = "cpu" if args.cpu else "cuda"
    # Decide autocast from the dtype actually in use, not from the raw CLI
    # string: on CPU (or after an fp32 fallback) the dtype is float32 and
    # CUDA autocast must stay off.
    use_autocast = (not args.cpu) and dtype != torch.float32
    prompt_text = GROUNDING + args.prompt

    for root, _dirs, files in os.walk(args.data_root):
        for file in files:
            ext = os.path.splitext(file)[1]
            if ext.lower() not in SUPPORTED_EXT:
                continue

            start_time = time.time()
            full_file_path = os.path.join(root, file)

            # Open the image once and close the file handle as soon as the
            # processor has turned it into tensors.
            with Image.open(full_file_path) as image:
                inputs = processor(text=prompt_text, images=image, return_tensors="pt")

            with torch.cuda.amp.autocast(enabled=use_autocast, dtype=dtype):
                generated_ids = model.generate(
                    pixel_values=inputs["pixel_values"].to(device),
                    input_ids=inputs["input_ids"].to(device),
                    attention_mask=inputs["attention_mask"].to(device),
                    image_embeds=None,
                    image_embeds_position_mask=inputs["image_embeds_position_mask"].to(device),
                    use_cache=True,
                    max_new_tokens=args.max_new_tokens,
                )

            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            # Remove remaining special tokens to get just the caption and entities.
            processed_text, entities = processor.post_process_generation(generated_text)

            if not args.keep_prompt:
                processed_text = remove_starting_string(args.prompt, processed_text)

            print(f"File: {full_file_path}, Generated caption: {processed_text}")

            name = os.path.splitext(full_file_path)[0]
            caption_path = f"{name}.txt"
            entities_path = f"{name}.ent"

            # Explicit UTF-8 so non-ASCII captions do not depend on the
            # platform's default encoding.
            if not os.path.exists(caption_path) or args.overwrite:
                with open(caption_path, "w", encoding="utf-8") as f:
                    f.write(processed_text)

            if args.save_entities and (not os.path.exists(entities_path) or args.overwrite):
                with open(entities_path, "w", encoding="utf-8") as entities_file:
                    entities_file.write(str(entities))

            if args.cpu:
                # Do not touch NVML in CPU mode: it fails without an NVIDIA driver.
                print(f"time taken: {time.time()-start_time:.2f} seconds")
            else:
                gpu_mb_used = get_gpu_memory_map()
                print(f"gpu usage: {gpu_mb_used:.1f} mb, time taken: {time.time()-start_time:.2f} seconds")
if __name__ == "__main__":
    # Banner goes out before any argument handling.
    print("Kosmos-2 captioning script")

    parser = argparse.ArgumentParser(description="Kosmos-2 captioning script")
    parser.add_argument(
        "--data_root",
        type=str,
        default="input",
        help="Path to folder of images to caption",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="Describe this image in detail: ",
        help="Prompt for generating caption",
    )
    parser.add_argument(
        "--keep_prompt",
        action="store_true",
        help="will keep the prompt at the start of the caption when saved",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=128,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--save_entities",
        action="store_true",
        help="Save coord box with entities to a separate .ent file",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="will overwrite .txt and .ent files if they exist",
    )
    parser.add_argument(
        "--cpu",
        action="store_true",
        help="use cpu instead of cuda",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="fp16",
        help="force a different dtype if using GPU (fp16, bf16, fp32) (default: fp16)",
    )
    args = parser.parse_args()

    # The usage text is shown on every run, after successful parsing.
    parser.print_help()

    # The prompt is always passed on with a leading space.
    if args.prompt[:1] != " ":
        args.prompt = " " + args.prompt

    print(f"Captioning images in {args.data_root} with prompt: {args.prompt}")
    for hint in (
        "Ideas for prompts:",
        " Describe this image in detail: (default)",
        " An image of ",
        " A two sentence description of this image:",
    ):
        print(hint)
    print()

    main(args)