feat: image-text (#3)
* fix: loading of in21k vit

* fix: config for gh200

* fix: hostfile

* fix: hostfile

* fix: requirements

* fix: multinode

* fix: import os

* fix: model name arg, dataset size, text query

* chore: ignore ids to keep

* feat: build index and save embeddings + rankings

* fix: update for cp medium

* fix: update build_map_index_filter

* fix: save faster, remove expensive gather

* fix: option to save

* fix: reqs

* fix: protobuf

* fix: beam req

* fix: beam req

* fix: faster multi-gpu index filtering

* feat: filter uuids from json files

* fix: print statement

* fix: remove duplicates

* fix: remove unused

* feat: precompute text embeddings

* fix: remove unused

* fix: updates

* fix: download and process

* fix: working now

* feat: cross attention/coca-esque model

* fix: reorder tar so npy near others

* feat: add back sampling with replacement (need to test)

* feat: train with pretrained

* fix: precomputed

* feat: cpu offload faster than regular opt w grad checkp

* fix: can load and resample now!

* feat: sugarcrepe eval

* fix: prefix for imagenet

* fix: add prefix in trainer for eval

* feat: datacomp evals for contrastors models

* fix: prefix + print after evals

* fix: if empty set to none

* fix: dataset-size pass via cli

* fix: fairness eval

* fix: allow null path for imagenet for testing

* feat: mlm + contrastive loss

* fix: imagenet fixes

* fix: deepspeed config

* fix: imagenet eval

* feat: three towers image <-> text, text <-> frozen

* fix: eval steps

* fix: hf model updates

* fix: vit pos embed

* feat: three towers current

* fix: multinode fixes

* fix: global rank in multinode

* fix: progbar only global rank 0

* feat: higher lr

* fix: eval strategy epochs logging fix

* feat: no clamp logits config

* feat: 3 epoch training

* feat: update hostfile

* fix: 10 epochs

* fix: update hostfile

* feat: upload embs to atlas

* feat: dino v1

* fix: grad check

* fix: clip model

* feat: 32k vit-l

* fix: update hostfile

* fix: workers

* fix: more logging

* fix: no wandb for now

* fix: try smaller vit

* fix: try more ds stuff

* fix: try openclip loss

* fix: remove unneeded print

* fix: test clip loss

* fix: 32k run

* fix: are evals broken?

* fix: 16k testing

* fix: evals

* fix: evals

* chore: logging

* fix: remove prints

* feat: config

* fix: remove rng, trust openclip

* fix: idk?

* fix: path

* fix: rank

* feat: ok now working L14

* feat: 32k higher lr exp

* feat: fb vit mae

* feat: mae train

* fix: map mae

* fix: sp

* fix: batch size

* feat: 10epoch 65k

* feat: higher lr

* feat: no wd

* feat: long train

* feat: 81k bs

* feat: 3 epoch 65k

* feat: 10 epoch

* fix: large 3 epoch train

* fix: workers

* fix: model utils loading

* fix: dataloader for datacomp1b

* fix: remove pdb

* fix: workers

* feat: dfn 2b

* fix: bs

* fix: bs

* fix: wandb

* fix: imagenet workers

* feat: try unidirectional

* fix: path for old h100

* fix: map

* fix: lets try this again

* fix: try fusing

* fix: bad code

* fix: 32k map fix

* fix: bs and default get for dataset

* fix: fused

* fix: dumb

* fix: try this

* feat: pos embed with swiglu gated

* fix: patch size

* fix: runs now

* fix: back to mlp

* fix: stage 3?

* fix: try again

* fix: remove pos embed

* fix: wtf

* feat: mean pool test again?

* feat: augments

* fix: try no checkpointing

* feat: 3 epoch augmentation train

* fix: no randaugment

* fix: dataset size

* feat: 65k run with augs

* fix: imagenet path

* feat: try resume training multinode

* fix: hostfile

* fix: no flip for this train

* fix: imagenet

* refactor: remove unused

* refactor: rename text_encoder -> nomic_encoder

* refactor: remove captioner

* chore: bump pydantic >= 2.0.0

* feat: eval for clip models

* feat: v1.5 config

* fix: hf code

* refactor: move hf tests to separate

* chore: remove unused

* refactor: remove

* refactor: unused code

* refactor: not used

* fix: remove unused

* refactor: remove xattn

* refactor: remove xattn

* fix: try to resume

* fix: v1.5

* fix: remove unused import

* fix: remove ema

* fix: remove ema

* fix: instructions

* feat: tracing code

* feat: add stacks

* feat: export_stacks=True

* fix: with_stack

* fix: tensorboard profiling (kind of) working

* fix: don't profile, test full thing

* feat: moar batch

* feat: train

* refactor: clean up code

* feat: download data

* fix: pydantic, workers crashing

* fix: prefix

* chore: ignore data folder

* feat: loadable hf model

* fix: map pooling bug

* fix: comment old pooling

* feat: flickr eval running

* feat: flickr to config

* feat: flickr eval train

* fix: flickr eval doesn't hang

* feat: biencoder test

* fix: enforce no dynamic ntk

* feat: unidirectional

* feat: base timm models

* fix: simplify vit pos_embed

* fix: cls token confusion

* feat: timm dinov2 with registers

* wip vit rotary

* feat: yolo 65k scratch vit

* fix: hostfile

* fix: revert back to bidirectional

* fix: spelling

* fix: path

* fix: wandb

* fix: shards

* fix: reqs

* feat: eva-style models, timm vit-base

* fix: timm vit-b 224 image

* feat: timm vit-b-16 first experiment

* fix: no flip

* feat: eva02 vit base

* feat: pooling heads from timm vit

* feat: add augreg vits as option

* fix: remove pooling heads

* fix: dumb renaming of model so eva loads with autoconfig

* feat: eva config for training

* fix: model loading

* feat: 65k eva 3 epoch train

* feat: map no clamp

* fix: hostfile

* fix: reduce workers

* fix: no clamp

* fix: config

* feat: v1.5 train

* fix: hostfile + config

* fix: config for lower lr

* fix: hamming

* fix: train

* feat: hf vision model code

* fix: hostfile

* fix: path

* refactor: clean up code base

* refactor: rename

* fix: remove hostfile

* refactor: remove sugarcrepe

* style: black and isort

* docs: readme and config fixes

* fix: trainers, come back later
zanussbaum authored Jun 5, 2024
1 parent c545be2 commit a547553
Showing 84 changed files with 9,902 additions and 1,844 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
data/
**/ids_to_keep_*.json
*counts.json*
medi*.json
nq*
17 changes: 15 additions & 2 deletions README.md
@@ -14,10 +14,12 @@
- Huggingface Support for easy loading of common models (Pythia/GPTNeoX, BERT, etc.)
- Masked Language Modeling (MLM) Pretraining
- [Matryoshka Representation Learning](https://arxiv.org/abs/2205.13147) for flexible embedding sizes
- [CLIP](https://arxiv.org/abs/2103.00020) and [LiT](https://arxiv.org/abs/2111.07991) style contrastive learning
- Support for loading popular ViT (e.g. [timm](https://huggingface.co/timm)) models

## Research

* [Nomic Embed: Training a Reproducible Long Context Text Embedder](https://arxiv.org/abs/2402.01613) by Zach Nussbaum, Jack Morris, Andriy Mulyar, and Brandon Duderstadt

## Getting Started and Requirements

@@ -41,7 +43,7 @@ pip3 install torch torchvision torchaudio
Install wheel, packaging, ninja for Flash Attention (so the builds don't take too long)

```bash
pip install wheel packaging ninja setuptools
```

Install Flash Attention and the custom kernels
@@ -141,6 +143,17 @@ This will train a bert model on all ~200M examples. To change the dataset, you c

To finetune `nomic-bert-embed-v1-unsupervised`, update the config to `configs/train/contrastive_finetune.yaml`.


## Training `nomic-embed-vision-v1.5`

To align a vision model, you will need to curate a large image-text dataset. More details can be found [here](https://github.com/rom1504/img2dataset).
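As a sketch of that curation step, img2dataset can turn a table of URL/caption pairs into the WebDataset `.tar` shards consumed below; the parquet file, column names, and output bucket here are placeholders, not paths used by this repository:

```bash
# Hypothetical input: metadata.parquet with "url" and "caption" columns.
img2dataset \
  --url_list metadata.parquet \
  --input_format parquet \
  --url_col url \
  --caption_col caption \
  --output_format webdataset \
  --output_folder s3://your-bucket/shards \
  --image_size 224 \
  --processes_count 16 \
  --thread_count 64
```

Each output shard comes with a `_stats.json` sidecar recording how many downloads succeeded, which the `dataset_size.py` script added in this commit reads.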

To align `nomic-embed-vision-v1.5` with `nomic-embed-text-v1.5`, you can run the following command:

```bash
deepspeed train.py --deepspeed_config=configs/deepspeed/image_text.json --config=configs/train/nomic_embed_vision_v1.5.yaml --dtype=bf16
```

### Generating Your Own Data

To generate your own data for any step of the pipeline, you can use the provided scripts in `scripts/text`.
24 changes: 20 additions & 4 deletions convert_to_hf.py
@@ -1,24 +1,40 @@
from argparse import ArgumentParser

from contrastors.models.biencoder import BiEncoder, BiEncoderConfig
from contrastors.models.dual_encoder import DualEncoder, DualEncoderConfig
from contrastors.models.huggingface import NomicBertConfig, NomicBertForPreTraining, NomicVisionModel


def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--private", action="store_true")
    parser.add_argument("--biencoder", action="store_true")
    parser.add_argument("--vision", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    if args.biencoder:
        config = BiEncoderConfig.from_pretrained(args.ckpt_path)
        model = BiEncoder.from_pretrained(args.ckpt_path, config=config)
        model = model.trunk
    elif args.vision:
        NomicBertConfig.register_for_auto_class()
        NomicVisionModel.register_for_auto_class("AutoModel")
        config = DualEncoderConfig.from_pretrained(args.ckpt_path)
        model = DualEncoder.from_pretrained(args.ckpt_path, config=config)
        vision = model.vision
        hf_config = NomicBertConfig(**model.vision.trunk.config.to_dict())
        model = NomicVisionModel(hf_config)

        state_dict = vision.state_dict()
        state_dict = {k.replace("trunk.", ""): v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
    else:
        config = NomicBertConfig.from_pretrained(args.ckpt_path)
        model = NomicBertForPreTraining.from_pretrained(args.ckpt_path, config=config)

    model.push_to_hub(args.model_name, private=args.private, use_temp_dir=False)
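The vision export path in `convert_to_hf.py` is mostly state-dict key surgery: the `DualEncoder` vision tower wraps an inner ViT under a `trunk.` prefix, which must be stripped before loading into the flat HF model. A minimal standalone sketch of that rename, using toy key names (the real checkpoint keys may differ):

```python
# Toy stand-in for the vision tower's state dict; the "trunk." prefix
# comes from the wrapper module, and the inner key names are hypothetical.
state_dict = {
    "trunk.patch_embed.proj.weight": "w0",
    "trunk.blocks.0.attn.Wqkv.weight": "w1",
}

# Strip the wrapper prefix so keys match the flat NomicVisionModel layout,
# mirroring the dict comprehension in convert_to_hf.py.
hf_state_dict = {k.replace("trunk.", ""): v for k, v in state_dict.items()}
print(sorted(hf_state_dict))
```

Because the rename is a plain string replace, it only works while `trunk.` never appears mid-key; a stricter version would use `k.removeprefix("trunk.")`.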
204 changes: 178 additions & 26 deletions requirements.txt
@@ -1,27 +1,179 @@
accelerate==0.30.1
aiobotocore==2.12.3
aiohttp==3.9.5
aioitertools==0.11.0
aiosignal==1.3.1
annotated-types==0.6.0
anyio==4.3.0
attrs==23.2.0
av==12.0.0
blis==0.7.11
boto3==1.34.69
botocore==1.34.69
braceexpand==0.1.7
cachetools==5.3.3
catalogue==2.0.10
certifi==2024.2.2
charset-normalizer==3.3.2
click==8.1.7
clip-benchmark==1.6.1
cloudpathlib==0.18.1
colorama==0.4.6
confection==0.1.4
contourpy==1.2.1
cycler==0.12.1
cymem==2.0.8
datasets==2.19.1
deepspeed==0.14.2
dill==0.3.8
distro==1.9.0
docker-pycreds==0.4.0
einops==0.8.0
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.6.0/en_core_web_sm-3.6.0-py3-none-any.whl#sha256=83276fc78a70045627144786b52e1f2728ad5e29e5e43916ec37ea9c26a11212
eval_type_backport==0.2.0
evaluate==0.4.2
filelock==3.14.0
flash-attn==2.5.8
fonttools==4.51.0
frozenlist==1.4.1
fsspec==2024.3.1
ftfy==6.2.0
gitdb==4.0.11
GitPython==3.1.43
google-api-core==2.19.0
google-auth==2.29.0
google-cloud-core==2.4.1
google-cloud-storage==2.16.0
google-crc32c==1.5.0
google-resumable-media==2.7.0
googleapis-common-protos==1.63.0
h11==0.14.0
hjson==3.1.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.23.0
idna==3.7
iniconfig==2.0.0
Jinja2==3.1.4
jmespath==1.0.1
joblib==1.4.2
jsonlines==4.0.0
kiwisolver==1.4.5
langcodes==3.4.0
language_data==1.2.0
lightning-utilities==0.11.2
littleutils==0.2.2
loguru==0.7.2
marisa-trie==1.1.1
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.8.4
mdurl==0.1.2
mpmath==1.3.0
mteb==1.8.11
multidict==6.0.5
multiprocess==0.70.15
murmurhash==1.0.10
networkx==3.3
ninja==1.11.1.1
nltk==3.8.1
nomic==3.0.27
numpy==1.24.2
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
ogb==1.3.6
onnx==1.16.0
onnxconverter-common==1.14.0
open-clip-torch==2.24.0
openai==1.28.1
outdated==0.2.2
packaging==24.0
pandas==2.2.2
pathlib_abc==0.1.1
pathy==0.11.0
peft==0.4.0
pillow==10.2.0
platformdirs==4.2.1
pluggy==1.5.0
polars==0.20.25
preshed==3.0.9
pretty-errors==1.2.25
proto-plus==1.23.0
protobuf==3.20.2
psutil==5.9.8
py-cpuinfo==9.0.0
pyarrow==16.0.0
pyarrow-hotfix==0.6
pyasn1==0.6.0
pyasn1_modules==0.4.0
pycocoevalcap==1.2
pycocotools==2.0.7
pydantic==2.7.1
pydantic_core==2.18.2
Pygments==2.18.0
PyJWT==2.8.0
pynvml==11.5.0
pyparsing==3.1.2
pytest==8.2.0
python-dateutil==2.9.0.post0
pytrec-eval-terrier==0.5.6
pytz==2024.1
PyYAML==6.0.1
regex==2024.5.10
requests==2.31.0
rich==13.7.1
rsa==4.9
s3fs==2024.3.1
s3transfer==0.10.1
safetensors==0.4.3
scikit-learn==1.4.2
scipy==1.13.0
seaborn==0.13.2
sentence-transformers==2.7.0
sentencepiece==0.2.0
sentry-sdk==2.1.1
setproctitle==1.3.3
six==1.16.0
smart-open==6.4.0
smmap==5.0.1
sniffio==1.3.1
spacy==3.6.1
spacy-legacy==3.0.12
spacy-loggers==1.0.4
srsly==2.4.8
sympy==1.12
tabulate==0.9.0
thinc==8.1.12
threadpoolctl==3.5.0
tiktoken==0.6.0
timm==1.0.3
tokenizers==0.19.1
torch==2.3.0
torchaudio==2.3.0
torchmetrics==1.4.0
torchvision==0.18.0
tqdm==4.66.4
transformers==4.40.2
triton==2.3.0
typer==0.9.4
typing_extensions==4.11.0
tzdata==2024.1
urllib3==2.2.1
wandb==0.17.0
wasabi==1.1.2
wcwidth==0.2.13
webdataset==0.2.86
wilds==2.0.0
wrapt==1.16.0
xxhash==3.4.1
yarl==1.9.4
72 changes: 72 additions & 0 deletions scripts/image/dataset_size.py
@@ -0,0 +1,72 @@
import concurrent.futures
import json
import multiprocessing as mp
from argparse import ArgumentParser
from pathlib import Path

import braceexpand
import fsspec
import pyarrow.parquet as pq
from tqdm import tqdm


def get_dataset_size(shard):
    fs = fsspec.filesystem('s3')
    try:
        with fs.open(shard.replace(".tar", "_stats.json"), "r") as f:
            stats = json.load(f)
        shard_size = int(stats["successes"])

    except Exception as e:
        print(f"Error reading {shard}: {e}")
        shard_size = 0

    return shard_size


if __name__ == "__main__":
    parser = ArgumentParser(description="Get the size of a dataset")
    parser.add_argument(
        "--shards",
        type=str,
        help="Path to the shards",
        default="s3://commonpool-medium/shards/{00000000..00012895}.tar",
    )
    parser.add_argument("--workers", type=int, help="Number of workers", default=mp.cpu_count())
    args = parser.parse_args()
    shards = args.shards

    shards_list = braceexpand.braceexpand(shards)
    shards_list = list(shards_list)

    num_shards = len(shards_list)
    print(num_shards)

    pbar = tqdm(total=num_shards)

    total_size = 0
    path2size = {}
    if args.workers == 1:
        for shard in shards_list:
            shard_size = get_dataset_size(shard)
            path2size[Path(shard).name] = shard_size
            total_size += shard_size
            pbar.update(1)
    else:
        with concurrent.futures.ProcessPoolExecutor(max_workers=mp.cpu_count()) as executor:
            future2shard = {executor.submit(get_dataset_size, shard): shard for shard in shards_list}

            for future in concurrent.futures.as_completed(future2shard):
                shard = future2shard[future]
                try:
                    shard_size = future.result()
                    path2size[Path(shard).name] = shard_size
                    total_size += shard_size
                except Exception as e:
                    print(f"Shard {shard} generated an exception: {e}")

                pbar.update(1)

    print(f"Total size: {total_size:,}")
    # with open("shard2size.json", "w") as f:
    #     json.dump(path2size, f, indent=4)
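The `--shards` default relies on the `braceexpand` package to turn `{00000000..00012895}` into one S3 path per shard. A stdlib-only sketch of that expansion (handling a single zero-padded numeric range, unlike the full library):

```python
import re


def expand_numeric_braces(pattern: str) -> list[str]:
    """Expand one zero-padded {start..end} range, e.g. s3://b/{000..002}.tar."""
    m = re.search(r"\{(\d+)\.\.(\d+)\}", pattern)
    if m is None:
        return [pattern]  # no range present; return the pattern unchanged
    start, end = m.group(1), m.group(2)
    width = len(start)  # preserve the zero padding of the start bound
    return [
        pattern[: m.start()] + str(i).zfill(width) + pattern[m.end():]
        for i in range(int(start), int(end) + 1)
    ]


print(expand_numeric_braces("s3://commonpool-medium/shards/{00000000..00000002}.tar"))
```

Expanding the shard list up front is what lets the script fan shards out to a process pool and sum the per-shard `successes` counts independently.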
