fix: index filtering (#20)
zanussbaum authored Jun 11, 2024 · 1 parent 56cfae8 · commit a4c197f
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion scripts/text/README.md
@@ -45,7 +45,7 @@ The dataset should be in the following jsonl format
To filter an existing dataset, run

```bash
-torchrun --nproc-per-node=<num_gpus> --dataset=<path_to_dataset_files_or_directory> --output_dir=<path_where_to_save_filtered_dataset> --query_key=<query_key_of_jsonl_file> --document_key=<document_of_key_jsonl_file>
+torchrun --nproc-per-node=<num_gpus> --dataset=<path_to_dataset_files_or_directory> --output_dir=<path_where_to_save_filtered_dataset> --query_key=<query_key_of_jsonl_file> --document_key=<document_of_key_jsonl_file> index_filtering.py
```

NOTE: You will most likely want to install `faiss-gpu`. To do so on a GPU with CUDA 12+, please follow [INSTALL_FAISS.md](INSTALL_FAISS.md).
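The `--query_key`/`--document_key` flags name fields inside each jsonl record. A minimal sketch of a compatible record layout (the field names `query` and `document` and the example text are assumptions for illustration, not mandated by the script):

```python
import json
import tempfile

# Hypothetical records; the real key names are whatever you pass
# via --query_key and --document_key.
records = [
    {"query": "what is faiss?", "document": "FAISS is a library for similarity search."},
    {"query": "install faiss-gpu", "document": "See INSTALL_FAISS.md for CUDA 12+ builds."},
]

# jsonl: one independent JSON object per line.
with tempfile.NamedTemporaryFile("w+", suffix=".jsonl", delete=False) as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")
    path = f.name

with open(path) as f:
    loaded = [json.loads(line) for line in f]

print(loaded[0]["query"])
```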
6 changes: 3 additions & 3 deletions scripts/text/index_filtering.py
@@ -14,7 +14,7 @@
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

-from contrastors.models.encoder import BertConfig, BertModel, bert_config_to_gpt2_config
+from contrastors.models.encoder import BertConfig, NomicBertModel, bert_config_to_nomic_config


def parse_args():
@@ -299,9 +299,9 @@ def filter_points(id2embeddings, batch_size=256):

model_name = "thenlper/gte-base"
hf_config = BertConfig.from_pretrained(model_name)
-config = bert_config_to_gpt2_config(hf_config)
+config = bert_config_to_nomic_config(hf_config)
model = (
-    BertModel.from_pretrained(model_name, config=config, add_pooling_layer=False)
+    NomicBertModel.from_pretrained(model_name, config=config, add_pooling_layer=False)
    .to(f"cuda:{dist.get_rank()}")
    .to(dtype=torch.float16)
)
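The corrected imports load an encoder whose embeddings feed an index used to filter the dataset. As a general illustration of embedding-based near-duplicate filtering (not the repository's implementation; the cosine threshold and greedy keep-first policy are assumptions):

```python
import math

def cosine(a, b):
    # Cosine similarity between two equal-length vectors.
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)

def filter_near_duplicates(embeddings, threshold=0.95):
    """Keep a vector's index only if it is not too similar to any already-kept vector."""
    kept = []
    kept_idx = []
    for i, emb in enumerate(embeddings):
        if all(cosine(emb, k) < threshold for k in kept):
            kept.append(emb)
            kept_idx.append(i)
    return kept_idx

vectors = [[1.0, 0.0], [0.99, 0.01], [0.0, 1.0]]
print(filter_near_duplicates(vectors))  # → [0, 2]: the second vector is a near-duplicate of the first
```

In practice the pairwise scan is replaced by an approximate-nearest-neighbor index (e.g. FAISS, as the README's install note suggests) so filtering scales to millions of embeddings.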
