Qwen2-VL-7B Inference Code #42

Open

insafim opened this issue Nov 6, 2024 · 1 comment

insafim commented Nov 6, 2024

Could you please provide your inference code for the Qwen2-VL-7B model? I am only getting 41.3% in the standard (4 options) setting.

Below is my inference code.


insafim commented Nov 6, 2024

```python
import os
import sys
import json
import torch
import yaml
import re
import ast
from PIL import Image
from tqdm import tqdm
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info # pip install qwen-vl-utils
import time

# Configuration

if len(sys.argv) == 2:
    PROMPT = sys.argv[1]
else:
    PROMPT = 'direct1'

MODEL = "Qwen/Qwen2-VL-7B-Instruct"
SETTING = 'standard'

MODEL_ID = "Qwen2-vl"

# Define file paths and other constants

PROMPTS_FILE = "Prompts/prompts_mmmu-pro.yaml"
LOCAL_DATA_PATH = "Datasets/MMMU-Pro/MMMU-Pro_standard_4options.json"
IMAGE_FOLDER = "Datasets/MMMU-Pro/Images-standard"
OUTPUT_JSON_PATH = f"Results/mmmu-pro_{MODEL_ID}_{SETTING}_{PROMPT}.json"
MAX_RETRY = 3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Device selected: {DEVICE}")

# Model and Processor Loading

model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL,
    torch_dtype=torch.float16,
    device_map="auto",
)
# device_map="auto" already dispatches the model across available devices;
# calling .to(DEVICE) on a dispatched model raises an error.

# Constrain the per-image visual token budget (documented on the Qwen2-VL model card)
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
processor = AutoProcessor.from_pretrained(MODEL, min_pixels=min_pixels, max_pixels=max_pixels)

# Load prompt configuration

with open(PROMPTS_FILE, "r") as file:
    prompt_config = yaml.safe_load(file)

# Helper functions

def replace_images_tokens(input_string):
    for i in range(1, 8):
        question_text = f"<image {i}>"
        query_text = "[image]"
        input_string = input_string.replace(question_text, query_text)
    return input_string

def parse_options(options):
    option_letters = [chr(ord("A") + i) for i in range(len(options))]
    choices_str = "\n".join([f"{option_letter}. {option}" for option_letter, option in zip(option_letters, options)])
    return choices_str

def construct_prompt(doc):
    question = doc["question"]
    parsed_options = parse_options(ast.literal_eval(str(doc["options"])))
    prompt = prompt_config[SETTING][PROMPT]
    question = f"{question}\n{parsed_options}\n{prompt}"
    return question

def mmmu_doc_to_text(doc):
    question = construct_prompt(doc)
    return replace_images_tokens(question)

def origin_mmmu_doc_to_visual(doc):
    visual = []
    print(f"Extracting images for doc id: {doc.get('id', 'Unknown')}")
    for i in range(1, 8):
        image_filename = doc.get(f'image_{i}')
        if image_filename:
            image_path = os.path.join(IMAGE_FOLDER, image_filename)
            if os.path.exists(image_path):
                print(f"Found image at {image_path}")
                visual.append(image_path)
            else:
                print(f"Image {image_filename} not found at {image_path}")
    return visual

def vision_mmmu_doc_to_visual(doc):
    image_filename = doc.get('image')
    if image_filename:
        image_path = os.path.join(IMAGE_FOLDER, image_filename)
        if os.path.exists(image_path):
            print(f"Found image at {image_path}")
            return [image_path]
        else:
            print(f"Image {image_filename} not found at {image_path}")
    return []

def process_prompt(data):
    if SETTING == 'standard':
        prompt = mmmu_doc_to_text(data)
        images = origin_mmmu_doc_to_visual(data)
    elif SETTING == 'vision':
        prompt = prompt_config['vision']
        images = vision_mmmu_doc_to_visual(data)
    else:
        raise ValueError(f"Unknown SETTING: {SETTING}")

    conversation_content = [{"type": "text", "text": prompt}]
    for img_path in images:
        conversation_content.append({"type": "image", "image": img_path})

    return prompt, conversation_content

def initialize_json(file_path):
    if not os.path.exists(file_path):
        print(f"Initializing new JSON file at: {file_path}")
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump([], f, ensure_ascii=False, indent=4)
    else:
        print(f"JSON file already exists at: {file_path}")

def load_existing_data(file_path):
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            existing_data = json.load(f)
        return existing_data
    except Exception as e:
        print(f"Error loading existing data: {e}. Starting with an empty dataset.")
        return []

def update_json(file_path, new_entry):
    try:
        with open(file_path, 'r+', encoding='utf-8') as f:
            data = json.load(f)
            data.append(new_entry)
            f.seek(0)
            json.dump(data, f, ensure_ascii=False, indent=4)
            f.truncate()
        print(f"Updated JSON file with new entry id: {new_entry.get('id', 'Unknown')}")
    except Exception as e:
        print(f"Error updating JSON file with new entry: {e}")

def run_and_save():
    initialize_json(OUTPUT_JSON_PATH)
    existing_data = load_existing_data(OUTPUT_JSON_PATH)
    processed_ids = {entry['id'] for entry in existing_data}

    try:
        print(f"Loading dataset from: {LOCAL_DATA_PATH}")
        with open(LOCAL_DATA_PATH, 'r', encoding='utf-8') as json_file:
            dataset = json.load(json_file)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        sys.exit(1)

    for idx, data in enumerate(tqdm(dataset, desc="Processing dataset")):
        entry_id = data.get('id', 'Unknown')
        if entry_id in processed_ids:
            print(f"Skipping already processed entry id: {entry_id}")
            continue

        prompt, conversation_content = process_prompt(data)
        messages = [{"role": "user", "content": conversation_content}]

        try:
            print(f"Preparing input for model inference with doc id: {entry_id}")
            # tokenize=False returns the prompt string (not token ids), matching
            # the official Qwen2-VL usage of processor.apply_chat_template
            text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            print(f"Text after applying chat template: {text[:100]}...")

            image_inputs, video_inputs = process_vision_info(messages)
            print(f"Image inputs: {image_inputs}, Video inputs: {video_inputs}")

            inputs = processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            ).to(DEVICE)
        except Exception as e:
            print(f"Error while processing prompt for id {entry_id}: {e}")
            data['response'] = ''
            update_json(OUTPUT_JSON_PATH, data)
            continue

        decoded_output = ""
        retry_count = 0

        # Retry generation up to MAX_RETRY times on empty output or errors
        while not decoded_output and retry_count < MAX_RETRY:
            try:
                output_ids = model.generate(**inputs, max_new_tokens=1024)
                generated_tokens = output_ids[:, inputs['input_ids'].shape[-1]:]
                decoded_output = processor.decode(generated_tokens[0], skip_special_tokens=True)
            except Exception as e:
                print(f"Generation error for id {entry_id}: {e}")
            if not decoded_output:
                retry_count += 1

        data['response'] = decoded_output if decoded_output else ''
        update_json(OUTPUT_JSON_PATH, data)

def main():
    start_time = time.time()  # Start timing
    run_and_save()
    end_time = time.time()  # End timing
    total_time = (end_time - start_time) / 60  # Convert to minutes
    print(f"\nTotal processing time: {total_time:.2f} minutes")

if __name__ == '__main__':
    main()
```
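
One thing worth double-checking besides the generation loop: how the option letter is extracted from `decoded_output` when scoring. With a direct-answer prompt, any response the parser fails to match counts as wrong, which alone can cost many accuracy points. Below is a heuristic sketch of the kind of extraction I mean; the `extract_answer` helper and its regex patterns are my own illustration, not the official MMMU-Pro parsing code:

```python
import re

def extract_answer(response, num_options=4):
    """Heuristic: pull a single option letter (A-D by default) from a model response."""
    letters = "".join(chr(ord("A") + i) for i in range(num_options))  # e.g. "ABCD"
    # Prefer an explicit statement such as "The answer is B" or "Answer: (B)"
    m = re.search(rf"answer\s*(?:is|:)?\s*\(?([{letters}])\)?\b", response, re.IGNORECASE)
    if m:
        return m.group(1).upper()
    # Otherwise accept a response that ends in a bare letter like "B." or "(B)"
    m = re.search(rf"\b\(?([{letters}])\)?[.,]?\s*$", response.strip())
    if m:
        return m.group(1).upper()
    return None  # unparseable; most scorers count this as wrong

# All of these should map to "B"
for r in ["The answer is B.", "Answer: (B)", "B"]:
    assert extract_answer(r) == "B"
```

If the scorer expects a stricter format than the model actually produces (e.g. a bare letter while the model writes a full sentence), the reported accuracy will understate the model.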
