Skip to content

Commit

Permalink
Merge pull request #1 from climatepolicyradar/feature/pods-1356-rewri…
Browse files Browse the repository at this point in the history
…te-indexer-to-use-vespa

Feature/pods 1356 rewrite embeddings generation to allow multiple models
  • Loading branch information
kdutia authored Jun 19, 2024
2 parents 83d680e + ec5eaee commit 19bdf68
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 167 deletions.
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
TARGET_LANGUAGES=en,fr

# Optional config. Defaults are set in src/config.py
LOCAL_DEVELOPMENT=false # sets cache folder to None so tests run locally
INDEX_ENCODER_CACHE_FOLDER=/models
SBERT_MODEL=msmarco-distilbert-dot-v5
ENCODING_BATCH_SIZE=32
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ RUN poetry install

# Download the sentence transformer model
RUN mkdir /models
RUN python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('sentence-transformers/msmarco-distilbert-dot-v5', cache_folder='/models')"
RUN python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('BAAI/bge-small-en-v1.5', cache_folder='/models')"

# Copy files to image
COPY ./src ./src
Expand Down
23 changes: 12 additions & 11 deletions cli/test/test_text2embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@

import numpy as np
from click.testing import CliRunner
import pytest

from cli.text2embeddings import run_as_cli
from cpr_data_access.parser_models import ParserOutput

from src.config import SBERT_MODELS


def test_run_encoder_local(
test_html_file_json,
Expand Down Expand Up @@ -38,23 +41,20 @@ def test_run_encoder_local(
Path(output_dir) / "test_pdf.json",
Path(output_dir) / "test_no_content_type.json",
}
assert set(Path(output_dir).glob("*.npy")) == {
Path(output_dir) / "test_html.npy",
Path(output_dir) / "test_pdf.npy",
Path(output_dir) / "test_no_content_type.npy",
}
assert len(list(Path(output_dir).glob("*.npy"))) == 3 * len(SBERT_MODELS)

for path in Path(output_dir).glob("*.json"):
assert ParserOutput.model_validate(json.loads(path.read_text()))

for path in Path(output_dir).glob("*.npy"):
assert np.load(str(path)).shape[1] == 768
# for path in Path(output_dir).glob("*.npy"):
# assert np.load(str(path)).shape[1] == 768

# test_html has the `has_valid_text` flag set to false, so the numpy file
# should only contain a description embedding
assert np.load(str(Path(output_dir) / "test_html.npy")).shape == (1, 768)
# # test_html has the `has_valid_text` flag set to false, so the numpy file
# # should only contain a description embedding
# assert np.load(str(Path(output_dir) / "test_html.npy")).shape == (1, 768)


@pytest.mark.skip(reason="Local development only for RAG")
def test_s3_client(
s3_bucket_and_region,
pipeline_s3_objects_main,
Expand All @@ -68,6 +68,7 @@ def test_s3_client(
assert list_response["KeyCount"] == len(pipeline_s3_objects_main)


@pytest.mark.skip(reason="Local development only for RAG")
def test_run_encoder_s3(
s3_bucket_and_region,
pipeline_s3_objects_main,
Expand Down Expand Up @@ -108,7 +109,7 @@ def test_run_encoder_s3(
s3_files_npy = [file for file in files if file.endswith(".npy")]

assert len(s3_files_json) == len(pipeline_s3_objects_main)
assert len(s3_files_npy) == len(pipeline_s3_objects_main)
assert len(s3_files_npy) == len(pipeline_s3_objects_main) * len(SBERT_MODELS)

for file in s3_files_json:
file_obj = pipeline_s3_client_main.client.get_object(
Expand Down
59 changes: 35 additions & 24 deletions cli/text2embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
encode_parser_output,
get_files_to_process,
get_Text2EmbeddingsInput_array,
encoder_name_as_slug,
)
from src.s3 import check_file_exists_in_s3, write_json_to_s3, save_ndarray_to_s3_as_npy

Expand Down Expand Up @@ -186,8 +187,11 @@ def run_embeddings_generation(
inputs=tasks, remove_block_types=config.BLOCKS_TO_FILTER
)

logger.info(f"Loading sentence-transformer model {config.SBERT_MODEL}")
encoder = SBERTEncoder(config.SBERT_MODEL)
encoders = dict()

for model_name in config.SBERT_MODELS:
logger.info(f"Loading sentence-transformer model {config.SBERT_MODELS}")
encoders[model_name] = SBERTEncoder(model_name)

logger.info(
"Encoding text from documents.",
Expand All @@ -213,33 +217,40 @@ def run_embeddings_generation(
extra={"props": {"task_output_path": task_output_path, "exception": e}},
)

embeddings_output_path = os.path.join(output_dir, task.document_id + ".npy")
for encoder_name, encoder in encoders.items():
encoder_name_slug = encoder_name_as_slug(encoder_name)
embeddings_output_path = os.path.join(
output_dir, task.document_id + "__" + encoder_name_slug + ".npy"
)

file_exists = (
check_file_exists_in_s3(embeddings_output_path)
if s3
else os.path.exists(embeddings_output_path)
)
if file_exists:
logger.info(
f"Embeddings output file '{embeddings_output_path}' already exists, "
"skipping processing."
file_exists = (
check_file_exists_in_s3(embeddings_output_path)
if s3
else os.path.exists(embeddings_output_path)
)
continue
if file_exists:
logger.info(
f"Embeddings output file '{embeddings_output_path}' already exists, "
"skipping processing."
)
continue

description_embedding, text_embeddings = encode_parser_output(
encoder, task, config.ENCODING_BATCH_SIZE, device=device
)
logger.info(
f"Encoding text using model {encoder_name}.",
)
description_embedding, text_embeddings = encode_parser_output(
encoder, task, config.ENCODING_BATCH_SIZE, device=device
)

combined_embeddings = (
np.vstack([description_embedding, text_embeddings])
if text_embeddings is not None
else description_embedding.reshape(1, -1)
)
combined_embeddings = (
np.vstack([description_embedding, text_embeddings])
if text_embeddings is not None
else description_embedding.reshape(1, -1)
)

save_ndarray_to_s3_as_npy(
combined_embeddings, embeddings_output_path
) if s3 else np.save(embeddings_output_path, combined_embeddings)
save_ndarray_to_s3_as_npy(
combined_embeddings, embeddings_output_path
) if s3 else np.save(embeddings_output_path, combined_embeddings)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 19bdf68

Please sign in to comment.