apple/ml-gbc

Graph Based Captioning (GBC)

Official repository for the paper Graph-Based Captioning: Enhancing Visual Descriptions by Interconnecting Region Captions.

Overview

Graph-based captioning interconnects region captions to form an integral, structured, and fine-grained description of an image. This codebase covers GBC data loading and processing, visualization, text-to-image generation, and captioning.

🚀 Get Started

To begin using the library, follow these steps to set up your environment. Depending on your intended use, you may need to install optional dependencies or perform additional setup.

Step 1: Set Up a Conda Environment

First, create and activate a new Conda environment with Python >= 3.10:

conda create -n gbc python=3.10 -y
conda activate gbc

Step 2: Install the Library (with Optional Dependencies)

Install the library. If needed, include optional dependencies tailored to specific features:

python -m pip install ".[optional_dependencies]"

Step 3: Additional Setup

  • Please refer to GBC Captioning for additional steps that are needed for running the captioning pipeline with specific models.
  • The GBC viewer is standalone and does not require the base gbc library but has separate installation steps.

⚙️ GBC Data Loading and Processing

We define two Python classes for our GBC annotations: GbcGraph and GbcGraphFull.

While the released dataset adheres to the GbcGraph class structure, all functionalities are implemented in GbcGraphFull. Therefore, GbcGraphFull is the recommended class for usage. You can load the released data into GbcGraphFull as follows.

from datasets import load_dataset
from gbc.data import GbcGraph, GbcGraphFull

ds = load_dataset("graph-based-captions/GBC1M", split="train") # or GBC10M
gbc_graphs = []
for record in ds.select(range(100)):
    gbc_graph = GbcGraphFull.model_validate(record)
    # Equivalently:
    # gbc_graph = GbcGraph.model_validate(record)
    # gbc_graph = GbcGraphFull.from_gbc_graph(gbc_graph)
    gbc_graphs.append(gbc_graph)

To load data from local files:

from gbc.utils import load_list_from_file
from gbc.data import GbcGraphFull

gbc_graphs = load_list_from_file("data/gbc/wiki/wiki_gbc_graphs.jsonl", class_type=GbcGraphFull)
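
For readers new to the .jsonl format, each line of the file holds one JSON record. The sketch below reads such a file into a list of typed objects using only the standard library. `ToyRecord` and `read_jsonl` are hypothetical names made up for this illustration; they are not part of gbc.utils, and the single `img_url` field is an assumption.

```python
import json
import tempfile
from dataclasses import dataclass
from pathlib import Path


@dataclass
class ToyRecord:
    """Hypothetical record type standing in for GbcGraphFull."""
    img_url: str


def read_jsonl(path, class_type):
    """Parse one JSON object per non-empty line and build a class_type from each."""
    records = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                records.append(class_type(**json.loads(line)))
    return records


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "toy.jsonl"
    path.write_text('{"img_url": "a.jpg"}\n{"img_url": "b.jpg"}\n', encoding="utf-8")
    loaded = read_jsonl(path, ToyRecord)

print([record.img_url for record in loaded])
```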

Data Processing

You can leverage the data processing script to process GBC graphs locally. To do this, specify a transform function in the configuration file. The transform function should take a GbcGraph or GbcGraphFull as input and return either a dictionary or a pydantic object. We provide several pre-defined transform functions in src/gbc/processing/data_transforms.
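
As a rough illustration of the transform contract described above (one graph in, one dictionary out), the sketch below uses a stand-in dataclass in place of GbcGraphFull. The class `ToyGraph`, its fields `img_url` and `vertices`, and the function `count_vertices` are assumptions made for this example and may not match the library's actual attributes.

```python
from dataclasses import dataclass, field


@dataclass
class ToyGraph:
    """Hypothetical stand-in for GbcGraph / GbcGraphFull (field names assumed)."""
    img_url: str
    vertices: list = field(default_factory=list)


def count_vertices(graph: ToyGraph) -> dict:
    """Example transform: take one graph and return a plain dictionary."""
    return {"img_url": graph.img_url, "num_vertices": len(graph.vertices)}


toy = ToyGraph(img_url="https://example.com/cat.jpg", vertices=["image", "cat"])
summary = count_vertices(toy)
print(summary)
```

A real transform would follow the same shape but read whichever attributes the GBC classes expose.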

For transform functions that compute CLIP or toxicity scores, ensure you have the processing optional dependency installed:

python -m pip install ".[processing]"

Data Processing Examples

File Format Conversion

When no configuration file is provided, the script can be used to convert files between .parquet, .jsonl, and .json formats. For example, the following command converts a .jsonl file to .parquet:

python scripts/processing/process_gbc.py \
    --input_paths data/gbc/wiki/wiki_gbc_graphs.jsonl \
    --input_formats .jsonl \
    --save_format .parquet \
    --save_dir tests/outputs/processing/conversion/

The --input_paths argument accepts one or more file or folder paths. If a folder is specified, all files within the folder and its subfolders matching the specified --input_formats will be processed.
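
The path expansion described above can be pictured with the following standard-library sketch. It is not the script's implementation: `collect_input_files` is a hypothetical helper, and details such as sorting and case-insensitive suffix matching are assumptions.

```python
import tempfile
from pathlib import Path


def collect_input_files(input_paths, input_formats):
    """Expand files and folders into a sorted list of files with matching suffixes."""
    suffixes = {fmt.lower() for fmt in input_formats}
    found = []
    for raw in input_paths:
        path = Path(raw)
        if path.is_dir():
            # rglob walks the folder and all of its subfolders
            candidates = [p for p in path.rglob("*") if p.is_file()]
        else:
            candidates = [path]
        found.extend(p for p in candidates if p.suffix.lower() in suffixes)
    return sorted(found)


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "sub").mkdir()
    (root / "a.jsonl").write_text("{}\n")
    (root / "sub" / "b.jsonl").write_text("{}\n")
    (root / "sub" / "c.parquet").write_bytes(b"")
    matched = [
        p.relative_to(root).as_posix()
        for p in collect_input_files([root], [".jsonl"])
    ]

print(matched)
```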

Graph to Text

The following converts graphs to dictionaries, where the text field contains specifically formatted text that accounts for the graph structure, and the image_path and image_url fields store the corresponding image path and URL.

python scripts/processing/process_gbc.py \
    --configs configs/processing/to_structured_text.yaml

Note: We directly specify all the arguments such as --input_paths, --input_formats, --save_format, and --save_dir in the configuration file, so there is no need to provide them as command-line arguments.
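
To give a feel for what "text that accounts for the graph structure" could look like, here is a toy rendering that indents each caption under its parent. The node layout (`label`, `caption`, `children`) and the indentation scheme are assumptions for illustration only; the actual output format is determined by the transform referenced in configs/processing/to_structured_text.yaml.

```python
def render_graph_text(node: dict, depth: int = 0) -> str:
    """Render a toy caption tree as indented text, one line per node."""
    line = "  " * depth + f"{node['label']}: {node['caption']}"
    children = [
        render_graph_text(child, depth + 1) for child in node.get("children", [])
    ]
    return "\n".join([line, *children])


# Hypothetical toy graph; real GBC graphs carry more information per vertex
toy_graph = {
    "label": "image",
    "caption": "A turtle sunbathing on a log.",
    "children": [
        {"label": "turtle", "caption": "A small green turtle."},
        {"label": "log", "caption": "A mossy log."},
    ],
}

text = render_graph_text(toy_graph)
print(text)
```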

CLIP/Toxicity Score Computation

To compute CLIP scores:

python scripts/processing/process_gbc.py \
    --configs configs/processing/compute_clip_scores.yaml

To compute toxicity scores:

python scripts/processing/process_gbc.py \
    --configs configs/processing/compute_toxicity_scores.yaml

To compute both CLIP and toxicity scores:

python scripts/processing/process_gbc.py \
    --configs configs/processing/compute_all_scores.yaml

Filtering

Basic filtering by vertex/caption type, bounding box size, etc.:

python scripts/processing/process_gbc.py \
    --configs configs/processing/relation_composition_filtering.yaml

More advanced filtering based on CLIP score:

python scripts/processing/process_gbc.py \
    --configs configs/processing/clip_filtering.yaml
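
Conceptually, score-based filtering keeps only the captions whose precomputed score reaches a threshold. The sketch below shows that idea on toy data; the field name `clip_score`, the threshold value, and the helper `filter_by_score` are all assumptions and do not reflect the configuration in clip_filtering.yaml.

```python
def filter_by_score(captions, min_score):
    """Keep only captions whose precomputed score reaches the threshold."""
    return [c for c in captions if c["clip_score"] >= min_score]


# Toy captions with made-up scores, for illustration only
captions = [
    {"text": "a turtle on a log", "clip_score": 0.31},
    {"text": "a spaceship", "clip_score": 0.12},
]

kept = filter_by_score(captions, min_score=0.25)
print([c["text"] for c in kept])
```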

🎞️ GBC Visualization

We provide a standalone viewer for exploring GBC-annotated data interactively. You can use it with our released datasets or your own data processed through our pipeline. Please see the viewer folder for more details.

🎨 GBC Text-to-Image

We introduce methods to generate GBC from a simple text prompt and to generate images from GBC. To enable these features, please install with the t2i optional dependency.

python -m pip install ".[t2i]"

Text-to-GBC

This process relies on our released 200M language model. You can run the prompt generation script using the following command, which saves the generated graphs as GbcGraph in a .json file (you can also save them in .parquet or .jsonl files).

python scripts/generation/t2gbc.py \
    --configs configs/generation/t2gbc_default.yaml \
    --prompt_file prompts/t2i/t2gbc_seed_with_entity_specification.yaml \
    --save_file tests/outputs/generation/t2gbc/gbc_prompt_gen.json \
    --prompt_gen_model_name_or_path graph-based-captions/GBC10M-PromptGen-200M

Prompt Specification

The prompts can be provided in two formats: .txt or .yaml.

Moreover, you can use brackets to indicate the objects from the seed prompts that need further description. For example, in prompts/t2i/t2gbc_seed_with_entity_specification.yaml, we have

- A cozy library room with a large wooden [bookshelf], a [leather armchair], and a small reading table with an old [lamp].
- A [turtle] sunbathing on a [log] in a quiet pond, with lily pads floating on the water.
- A [frog] in a mystical forest filled with oversized mushrooms.
- A steampunk-inspired workshop with gears on the walls and a [mechanical cat].

This requires the model to create three children from the root node with labels bookshelf, leather armchair, and lamp for the first prompt, two children from the root node with labels turtle and log for the second prompt, and so on.
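
The bracket convention can be illustrated with a small regular-expression sketch that pulls out the requested entity labels and recovers the plain prompt. The helpers `extract_entities` and `strip_brackets` are hypothetical and only mirror the convention described above; they are not the parser used by t2gbc.py.

```python
import re

BRACKETED = re.compile(r"\[([^\[\]]+)\]")


def extract_entities(prompt: str) -> list[str]:
    """Return the bracketed labels of a seed prompt, in order of appearance."""
    return BRACKETED.findall(prompt)


def strip_brackets(prompt: str) -> str:
    """Return the prompt with the bracket markers removed."""
    return BRACKETED.sub(r"\1", prompt)


prompt = (
    "A [turtle] sunbathing on a [log] in a quiet pond, "
    "with lily pads floating on the water."
)
entities = extract_entities(prompt)
plain = strip_brackets(prompt)
print(entities)
print(plain)
```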

GBC-to-Image

Our training-free approach samples images from GBC prompts using SDXL. To use it, choose an appropriate configuration file and run

python scripts/generation/gbc2i.py \
    --configs path/to/config.yaml \
    --prompt_files prompt_file_1 prompt_file_2 \
    --neg_prompt_files neg_prompt_file_1 neg_prompt_file_2 \
    --save_dir output_folder

Prompt Specification

The prompts can be provided in two formats: .yaml files (typically user-specified) and GBC annotations (.parquet, .jsonl, .json) produced by our tools.

Examples

We explain below which configuration to use for each of the algorithms shown in the figure. We use the default negative prompt file prompts/t2i/neg_default.yaml.

Vanilla SDXL Sampling (Column 5)

python scripts/generation/gbc2i.py \
    --configs configs/generation/gbc2i/sampling_base.yaml \
    --prompt_files prompts/t2i/t2gbc_seed.yaml

Note: We simply sample from plain text here. There is no specific implementation for sampling from a concatenation of GBC prompts.

Sampling from Prompts and Bounding Boxes (Column 2)

python scripts/generation/gbc2i.py \
    --configs configs/generation/gbc2i/sampling_region_base.yaml \
    --prompt_files prompts/t2i/banana_apple.yaml prompts/t2i/living_room.yaml

With IP Adapter

python scripts/generation/gbc2i.py \
    --configs configs/generation/gbc2i/sampling_region_base_ipa.yaml \
    --prompt_files prompts/t2i/dog_cat_ref_image.yaml

Sampling from Prompts, Bounding Boxes, and Graph; Prompts Encoded with Context (Column 3)

python scripts/generation/gbc2i.py \
    --configs configs/generation/gbc2i/sampling_region_gbc_encode_with_context.yaml \
    --prompt_files prompts/t2i/banana_apple.yaml prompts/t2i/living_room.yaml

With IP Adapter

python scripts/generation/gbc2i.py \
    --configs configs/generation/gbc2i/sampling_region_gbc_encode_with_context_ipa.yaml \
    --prompt_files prompts/t2i/dog_cat_ref_image.yaml

Sampling from Prompts, Bounding Boxes, and Graph; Prompts Encoded without Context (Column 4)

python scripts/generation/gbc2i.py \
    --configs configs/generation/gbc2i/sampling_region_gbc_encode_without_context.yaml \
    --prompt_files prompts/t2i/banana_apple.yaml prompts/t2i/living_room.yaml

With IP Adapter

python scripts/generation/gbc2i.py \
    --configs configs/generation/gbc2i/sampling_region_gbc_encode_without_context_ipa.yaml \
    --prompt_files prompts/t2i/dog_cat_ref_image.yaml

Sampling from Prompts and Graph (Column 1)

Prompts encoded with context

python scripts/generation/gbc2i.py \
    --configs configs/generation/gbc2i/sampling_gbc_encode_with_context.yaml \
    --prompt_files prompts/t2i/banana_apple_graph_only.yaml prompts/t2i/living_room_graph_only.yaml

Prompts encoded without context

python scripts/generation/gbc2i.py \
    --configs configs/generation/gbc2i/sampling_gbc_encode_without_context.yaml \
    --prompt_files prompts/t2i/banana_apple_graph_only.yaml prompts/t2i/living_room_graph_only.yaml

Prompts encoded with context + IP Adapter

python scripts/generation/gbc2i.py \
    --configs configs/generation/gbc2i/sampling_gbc_encode_with_context_ipa.yaml \
    --prompt_files prompts/t2i/dog_cat_ref_image_graph_only.yaml prompts/t2i/living_room_graph_only.yaml

Note: This algorithm is only expected to work when the underlying graph is a star graph. Moreover, we use prompt files that assign empty bounding boxes to all but the first prompt. This ensures that only the first prompt is used in the first phase (in fact, the first phase runs the sampling algorithm that uses both prompts and bounding boxes).
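
For reference, a star graph here means that every non-root node is a direct child of the root. The check below is a minimal sketch over a list of (parent, child) pairs; the edge-list representation and the helper `is_star_graph` are assumptions for illustration and not how the library stores or validates graphs.

```python
def is_star_graph(edges, root):
    """Return True if every edge goes from `root` to a distinct leaf node.

    `edges` is a list of (parent, child) pairs.
    """
    children = [child for parent, child in edges if parent == root]
    all_from_root = len(children) == len(edges)
    distinct_leaves = len(set(children)) == len(children) and root not in children
    return all_from_root and distinct_leaves


star = [("image", "banana"), ("image", "apple")]
chain = [("image", "table"), ("table", "apple")]

star_ok = is_star_graph(star, root="image")
chain_ok = is_star_graph(chain, root="image")
print(star_ok, chain_ok)
```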

Text-to-Image with GBC as middleware

You can use generated prompts to generate images as follows. In this case, you can also specify graph_transform in the configuration file (for example, using one of the transforms defined in src/gbc/processing/data_transforms).

python scripts/generation/gbc2i.py \
    --configs configs/generation/gbc2i/sampling_region_gbc_encode_with_context.yaml \
              configs/generation/graph_transform_ex.yaml \
    --prompt_files data/gbc/prompt_gen/library_turtle_frog_steamponk.json

Alternatively, the following command runs the two phases together.

python scripts/generation/t2gbc2i.py \
    --configs configs/generation/t2gbc_default.yaml \
              configs/generation/gbc2i/sampling_region_gbc_encode_with_context.yaml \
              configs/generation/graph_transform_ex.yaml \
    --prompt_file prompts/t2i/t2gbc_seed_with_entity_specification.yaml \
    --neg_prompt_file prompts/t2i/neg_default.yaml \
    --save_dir tests/outputs/generation/t2gbc2i/sdxl-region-gbc-with-context

Once the GBC and images are generated, you can visualize them with our viewer as follows (assuming that you have installed and built the viewer):

python viewer/server/api.py \
    --path tests/outputs/generation/t2gbc2i/sdxl-region-gbc-with-context \
    --frontend_path viewer/dist

🖋️ GBC Captioning

To generate GBC annotations for your own images, install the captioning optional dependency.

python -m pip install ".[captioning]"

Moreover, depending on the models used in the captioning process, you may need additional steps to download pre-trained models or install extra dependencies. By default, the code uses Pixtral for MLLM queries and GroundingDINO as the detection model.

Use Pixtral

We use vllm to load neuralmagic/pixtral-12b-FP8-dynamic. No extra installation step is needed, since vllm is already included in the dependencies. However, you may need to adjust gpu_memory_utilization in the config file to match your system's memory capacity.

Use Grounding DINO

We leverage the GroundingDINO implementation from Huggingface. Therefore, no additional preparation steps are needed for this model. However, it does require the transformers version to be >= 4.40.

Use LLaVA 1.6

To use LLaVA-1.6 for captioning, run the following script to set up the necessary dependencies and download the required models:

bash scripts/setup/setup_llava_query.sh

Note: This assumes that a GPU is available and installs llama-cpp-python with CUDA support. You may need to modify the installation command depending on your specific environment.

Use YOLO-World

Our implementation is based on the original YOLO-World repository. The following script clones the repository, installs the necessary dependencies, and downloads the required models.

bash scripts/setup/setup_yolo_world_detection.sh

Running from the CLI

You can then use the captioning script as follows:

python scripts/captioning/run_gbc_captioning.py \
    --img_paths img_folder_1 img_folder_2 img_1.png img_2.jpg \
    --save_dir output_folder \
    --config_file configs/captioning/default.yaml \
    --save_frequency 50 \
    --save_images \
    --attempt_resume \
    --save_format .jsonl .parquet

The above command performs GBC captioning for img_1.png, img_2.jpg, and all the images found recursively under img_folder_1 and img_folder_2. The results are then saved to output_folder in both .jsonl and .parquet formats, each containing all annotations.

Arguments

  • --img_paths: Specify the image files and folders to be captioned.
  • --save_dir: Directory where the output will be saved.
  • --config_file: Path to the configuration file.
  • --save_frequency: How often intermediate artifacts are saved, measured in number of completed actions. This makes it possible to resume the process if it is interrupted.
  • --save_images: Save all the input images used in each query.
  • --attempt_resume: Attempt to resume the captioning process if it was interrupted.
  • --save_format: Specify the output formats (.json, .jsonl, or .parquet).
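
The interplay between --save_frequency and --attempt_resume can be pictured with the following sketch of periodic checkpointing. It is only a conceptual illustration under assumed names: `process_with_checkpoints`, the JSON checkpoint file, and the placeholder action are hypothetical and do not describe how the captioning script stores its intermediate artifacts.

```python
import json
import tempfile
from pathlib import Path


def process_with_checkpoints(items, checkpoint, save_frequency, attempt_resume=False):
    """Process items in order, saving progress every `save_frequency` completions."""
    results = []
    if attempt_resume and checkpoint.exists():
        # Reload what an earlier, interrupted run already finished
        results = json.loads(checkpoint.read_text())
    for item in items[len(results):]:
        results.append(item.upper())  # placeholder for one completed action
        if len(results) % save_frequency == 0:
            checkpoint.write_text(json.dumps(results))
    checkpoint.write_text(json.dumps(results))  # final save
    return results


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "progress.json"
    # Pretend an earlier run stopped after two of the four items
    ckpt.write_text(json.dumps(["A", "B"]))
    resumed = process_with_checkpoints(
        ["a", "b", "c", "d"], ckpt, save_frequency=2, attempt_resume=True
    )
    saved = json.loads(ckpt.read_text())

print(resumed)
```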

Note

  • The --save_images flag is primarily for debugging and understanding the process. It is not recommended for captioning a large number of images as it will save all input images used in each query.
  • There are additional arguments available that are not listed here. Please refer to the script itself or use the --help option for a complete list of arguments.

Configuration File

The captioning process can use an optional configuration file to define the captioning details. By leveraging Hydra, you can specify the exact implementation for each query type under the queries section: Image, Entity, Relation, Composition, and Detection(s). Refer to configs/captioning/default.yaml for an example. You can also provide custom implementations for each query type, as long as they subclass Action and adhere to the required method signature.

*The pipeline_config section is ignored when running the captioning script.
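
Hydra configurations commonly select an implementation through a `_target_` key holding a dotted import path, with the remaining keys passed as keyword arguments. The standard-library sketch below mimics that mechanism so it can run without Hydra installed. Whether configs/captioning/default.yaml uses exactly this layout is an assumption, and `fractions.Fraction` is only a stand-in target; a real query entry would point to an Action subclass.

```python
from importlib import import_module


def instantiate(cfg: dict):
    """Build the object named by `_target_`, passing the other keys as keyword arguments."""
    module_name, _, attr = cfg["_target_"].rpartition(".")
    target = getattr(import_module(module_name), attr)
    kwargs = {key: value for key, value in cfg.items() if key != "_target_"}
    return target(**kwargs)


# Stand-in configuration entry (hypothetical; not taken from default.yaml)
query_cfg = {"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4}
obj = instantiate(query_cfg)
print(type(obj).__name__, obj)
```

With Hydra installed, `hydra.utils.instantiate` plays the role of the `instantiate` helper above.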

Running with Python

The captioning actions and pipelines are all defined in gbc.captioning. It suffices to instantiate a GbcPipeline object.

from omegaconf import OmegaConf
from objprint import op

from gbc.utils import save_list_to_file
from gbc.captioning import GbcPipeline

config = OmegaConf.load("configs/captioning/default.yaml")
gbc_pipeline = GbcPipeline.from_config(config)

img_file_1 = "data/images/wiki/Eiffel_tower_0.jpg"
img_file_2 = "data/images/wiki/Eiffel_tower_1.jpg"

# Perform captioning on a single image
gbc = gbc_pipeline.run_gbc_captioning(img_file_1)
# Pretty print the GBC graph
op(gbc[0].model_dump())

# Perform captioning on multiple images
gbcs = gbc_pipeline.run_gbc_captioning([img_file_1, img_file_2])
# Save the GBC graphs, can save as json, jsonl, or parquet
save_list_to_file(gbcs, "tests/outputs/captioning/gbc_eiffel_tower.json")

In this case, further configuration details for the pipeline should be listed in the pipeline_config section of the configuration file, unless you pass them directly to GbcPipeline.from_config.

Alternatively, you can use the functional interface.

from omegaconf import OmegaConf
from objprint import op

from gbc.utils import save_list_to_file
from gbc.captioning import run_gbc_captioning

config = OmegaConf.load("configs/captioning/default.yaml")

img_file_1 = "data/images/wiki/Eiffel_tower_0.jpg"
img_file_2 = "data/images/wiki/Eiffel_tower_1.jpg"

# Perform captioning on a single image
gbc = run_gbc_captioning(img_file_1, config, include_relation_query=False)
# Pretty print the GBC graph
op(gbc[0].model_dump())

# Perform captioning on multiple images
gbcs = run_gbc_captioning(
    [img_file_1, img_file_2], config, batch_query=True, batch_size=8
)
# Save the GBC graphs, can save as json, jsonl, or parquet
save_list_to_file(gbcs, "tests/outputs/captioning/gbc_batch_eiffel_tower.json")
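
As background for the batch_size argument, batching generally means splitting the pending work into fixed-size groups that are sent to the model together. The helper below is a generic sketch of that splitting; `batched` is a hypothetical function, and how run_gbc_captioning actually groups its queries is not shown here.

```python
def batched(items, batch_size):
    """Yield consecutive slices of `items` with at most `batch_size` elements each."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# Toy file names, for illustration only
image_files = [f"img_{i}.jpg" for i in range(5)]
batches = list(batched(image_files, batch_size=2))
print(batches)
```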

License

This project is distributed under the following terms:

Acknowledgement

Our codebase builds upon several open-source projects. We highlight some of the most important ones below:

Citation

@article{GBC2024,
  title={Graph-Based Captioning: Enhancing Visual Descriptions by Interconnecting Region Captions},
  author={Yu-Guan Hsieh and Cheng-Yu Hsieh and Shih-Ying Yeh and Louis Béthune and Hadi Pouransari and Pavan Kumar Anasosalu Vasu and Chun-Liang Li and Ranjay Krishna and Oncel Tuzel and Marco Cuturi},
  journal={arXiv preprint arXiv:2407.06723},
  year={2024}
}
