From 37e5c81592c77688d7e3aef5f643e9927e9a81d5 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Fri, 1 Nov 2024 12:47:08 +0100 Subject: [PATCH] convert to numpy --- ...ual_pdf_rag_with_vespa_colpali_cloud.ipynb | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/docs/sphinx/source/examples/visual_pdf_rag_with_vespa_colpali_cloud.ipynb b/docs/sphinx/source/examples/visual_pdf_rag_with_vespa_colpali_cloud.ipynb index 624dcfbb..276dbf2f 100644 --- a/docs/sphinx/source/examples/visual_pdf_rag_with_vespa_colpali_cloud.ipynb +++ b/docs/sphinx/source/examples/visual_pdf_rag_with_vespa_colpali_cloud.ipynb @@ -993,7 +993,6 @@ " np.ndarray: Embeddings for the images, shape\n", " (len(images), processor.max_patch_length (1030 for ColPali), model.config.hidden_size (Patch embedding dimension - 128 for ColPali)).\n", " \"\"\"\n", - " embeddings_list = []\n", "\n", " def collate_fn(batch):\n", " # Batch is a list of images\n", @@ -1005,20 +1004,24 @@ " collate_fn=collate_fn,\n", " )\n", "\n", - " for batch_doc in tqdm(dataloader, desc=\"Generating embeddings\"):\n", + " embeddings_list = []\n", + " for batch in tqdm(dataloader):\n", " with torch.no_grad():\n", - " # Move batch to the device\n", - " batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}\n", - " embeddings_batch = model(**batch_doc)\n", - " embeddings_list.append(torch.unbind(embeddings_batch.to(\"cpu\")))\n", - " # Concatenate all embeddings and create a numpy array\n", - " all_embeddings = np.concatenate(embeddings_list, axis=0)\n", + " batch = {k: v.to(model.device) for k, v in batch.items()}\n", + " embeddings_batch = model(**batch)\n", + " # Convert tensor to numpy array and append to list\n", + " embeddings_list.extend(\n", + " [t.cpu().numpy() for t in torch.unbind(embeddings_batch)]\n", + " )\n", + "\n", + " # Stack all embeddings into a single numpy array\n", + " all_embeddings = np.stack(embeddings_list, axis=0)\n", " return all_embeddings" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "id": "dece992e", "metadata": { "colab": { @@ -1032,8 +1035,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Generating embeddings: 0%| | 0/10 [00:00