Skip to content

Commit

Permalink
convert to numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasht86 committed Nov 1, 2024
1 parent 4ade2f3 commit 37e5c81
Showing 1 changed file with 15 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,6 @@
" np.ndarray: Embeddings for the images, shape\n",
" (len(images), processor.max_patch_length (1030 for ColPali), model.config.hidden_size (Patch embedding dimension - 128 for ColPali)).\n",
" \"\"\"\n",
" embeddings_list = []\n",
"\n",
" def collate_fn(batch):\n",
" # Batch is a list of images\n",
Expand All @@ -1005,20 +1004,24 @@
" collate_fn=collate_fn,\n",
" )\n",
"\n",
" for batch_doc in tqdm(dataloader, desc=\"Generating embeddings\"):\n",
" embeddings_list = []\n",
" for batch in tqdm(dataloader):\n",
" with torch.no_grad():\n",
" # Move batch to the device\n",
" batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}\n",
" embeddings_batch = model(**batch_doc)\n",
" embeddings_list.append(torch.unbind(embeddings_batch.to(\"cpu\")))\n",
" # Concatenate all embeddings and create a numpy array\n",
" all_embeddings = np.concatenate(embeddings_list, axis=0)\n",
" batch = {k: v.to(model.device) for k, v in batch.items()}\n",
" embeddings_batch = model(**batch)\n",
" # Convert tensor to numpy array and append to list\n",
" embeddings_list.extend(\n",
" [t.cpu().numpy() for t in torch.unbind(embeddings_batch)]\n",
" )\n",
"\n",
" # Stack all embeddings into a single numpy array\n",
" all_embeddings = np.stack(embeddings_list, axis=0)\n",
" return all_embeddings"
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 14,
"id": "dece992e",
"metadata": {
"colab": {
Expand All @@ -1032,8 +1035,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Generating embeddings: 0%| | 0/10 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n",
"Generating embeddings: 100%|██████████| 10/10 [00:21<00:00, 2.16s/it]\n"
"100%|██████████| 10/10 [00:22<00:00, 2.20s/it]\n"
]
}
],
Expand All @@ -1053,7 +1055,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 15,
"id": "eedcf944",
"metadata": {},
"outputs": [
Expand All @@ -1063,7 +1065,7 @@
"(10, 1030, 128)"
]
},
"execution_count": 36,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
Expand Down

0 comments on commit 37e5c81

Please sign in to comment.