From ac5f550192656a1696860de8bf43149313c4e5d3 Mon Sep 17 00:00:00 2001 From: zhuwenxing Date: Tue, 29 Oct 2024 15:35:37 +0800 Subject: [PATCH] test: add testcases for bm25 function Signed-off-by: zhuwenxing --- tests/requirements.txt | 4 +- tests/testcases/test_restore_backup.py | 111 ++++++++++++++++--------- tests/utils/util_common.py | 61 ++++++++++++++ tests/utils/util_pymilvus.py | 85 ++++++++++++++++++- 4 files changed, 217 insertions(+), 44 deletions(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index 9845f900..174f0c07 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -12,8 +12,8 @@ allure-pytest==2.7.0 pytest-print==0.2.1 pytest-level==0.1.1 pytest-xdist==2.5.0 -pymilvus==2.4.5rc11 -pymilvus[bulk_writer]==2.4.5rc11 +pymilvus==2.5.0rc104 +pymilvus[bulk_writer]==2.5.0rc104 pytest-rerunfailures==9.1.1 git+https://github.com/Projectplace/pytest-tags ndg-httpsclient diff --git a/tests/testcases/test_restore_backup.py b/tests/testcases/test_restore_backup.py index df55dfdb..59cafc6d 100644 --- a/tests/testcases/test_restore_backup.py +++ b/tests/testcases/test_restore_backup.py @@ -6,12 +6,14 @@ import jax.numpy as jnp import random from collections import defaultdict -from pymilvus import db, list_collections, Collection, DataType +from pymilvus import db, list_collections, Collection, DataType, Function, FunctionType from base.client_base import TestcaseBase from common import common_func as cf from common import common_type as ct from common.common_type import CaseLabel from utils.util_log import test_log as log +from utils.util_common import analyze_documents +from utils.util_pymilvus import create_index_for_vector_fields from api.milvus_backup import MilvusBackupClient from faker import Faker fake_en = Faker("en_US") @@ -193,7 +195,7 @@ def test_milvus_restore_back_with_db_support(self): @pytest.mark.parametrize("include_dynamic", [True]) @pytest.mark.parametrize("include_json", [True]) @pytest.mark.tags(CaseLabel.L0) - def test_milvus_restore_back_with_json_dynamic_schema_partition_key(self, include_json, include_dynamic, include_partition_key): + def test_milvus_restore_back_with_new_dynamic_schema_and_partition_key(self, include_json, include_dynamic, include_partition_key): self._connect() name_origin = cf.gen_unique_str(prefix) back_up_name = cf.gen_unique_str(backup_prefix) @@ -651,20 +653,24 @@ def test_milvus_restore_back_with_f16_bf16_datatype(self, include_dynamic, inclu @pytest.mark.parametrize("include_partition_key", [True]) @pytest.mark.parametrize("include_dynamic", [True]) @pytest.mark.parametrize("enable_text_match", [True]) + @pytest.mark.parametrize("enable_full_text_search", [True]) @pytest.mark.tags(CaseLabel.MASTER) - def test_milvus_restore_back_with_sparse_vector_text_match_datatype(self, include_dynamic, include_partition_key, enable_text_match): + def test_milvus_restore_back_with_sparse_vector_text_match_datatype(self, include_dynamic, include_partition_key, enable_text_match, enable_full_text_search): self._connect() name_origin = cf.gen_unique_str(prefix) back_up_name = cf.gen_unique_str(backup_prefix) fields = [cf.gen_int64_field(name="int64", is_primary=True), - cf.gen_int64_field(name="key"), - cf.gen_string_field(name="text", enable_match=enable_text_match), - cf.gen_json_field(name="json"), - cf.gen_array_field(name="var_array", element_type=DataType.VARCHAR), - cf.gen_array_field(name="int_array", element_type=DataType.INT64), - cf.gen_float_vec_field(name="float_vector", dim=128), - cf.gen_sparse_vec_field(name="sparse_vector"), - ] + cf.gen_int64_field(name="key"), + cf.gen_string_field(name="text", enable_match=enable_text_match, enable_tokenizer=True), + cf.gen_json_field(name="json"), + cf.gen_array_field(name="var_array", element_type=DataType.VARCHAR), + cf.gen_array_field(name="int_array", element_type=DataType.INT64), + cf.gen_float_vec_field(name="float_vector", dim=128), + cf.gen_sparse_vec_field(name="sparse_vector"), + # cf.gen_sparse_vec_field(name="bm25_sparse_vector"), + ] + if enable_full_text_search: + fields.append(cf.gen_sparse_vec_field(name="bm25_sparse_vector")) if include_partition_key: partition_key = "key" default_schema = cf.gen_collection_schema(fields, @@ -673,40 +679,38 @@ def test_milvus_restore_back_with_sparse_vector_text_match_datatype(self, includ else: default_schema = cf.gen_collection_schema(fields, enable_dynamic_field=include_dynamic) - + if enable_full_text_search: + bm25_function = Function( + name="text_bm25_emb", + function_type=FunctionType.BM25, + input_field_names=["text"], + output_field_names=["bm25_sparse_vector"], + params={}, + ) + default_schema.add_function(bm25_function) collection_w = self.init_collection_wrap(name=name_origin, schema=default_schema, active_trace=True) nb = 3000 rng = np.random.default_rng() - data = [ - [i for i in range(nb)], - [i % 3 for i in range(nb)], - [fake_en.text() for i in range(nb)], - [{f"key_{str(i)}": i} for i in range(nb)], - [[str(x) for x in range(10)] for i in range(nb)], - [[int(x) for x in range(10)] for i in range(nb)], - [[np.float32(i) for i in range(128)] for _ in range(nb)], - [{ - d: rng.random() for d in random.sample(range(1000), random.randint(20, 30)) - } for _ in range(nb)], - ] + + data = [] + for i in range(nb): + tmp = { + "int64": i, + "key": i % 3, + "text": fake_en.text(), + "json": {f"key_{str(i)}": i}, + "var_array": [str(x) for x in range(10)], + "int_array": [int(x) for x in range(10)], + "float_vector": [np.float32(i) for i in range(128)], + "sparse_vector": { + d: rng.random() for d in random.sample(range(1000), random.randint(20, 30)) + }, + } + if include_dynamic: + tmp[f"dynamic_{str(i)}"] = i + data.append(tmp) + texts = [d["text"] for d in data] collection_w.insert(data=data) - if include_dynamic: - data = [ - { - "int64": i, - "key": i % 3, - "text": fake_en.text(), - "json": {f"key_{str(i)}": i}, - "var_array": [str(x) for x in range(10)], - "int_array": [int(x) for x in range(10)], - "float_vector": [np.float32(i) for i in range(128)], - "sparse_vector": { - d: rng.random() for d in random.sample(range(1000), random.randint(20, 30)) - }, - f"dynamic_{str(i)}": i - } for i in range(nb, nb*2) - ] - collection_w.insert(data=data) res = client.create_backup({"async": False, "backup_name": back_up_name, "collection_names": [name_origin]}) log.info(f"create_backup {res}") res = client.list_backup() @@ -727,6 +731,31 @@ def test_milvus_restore_back_with_sparse_vector_text_match_datatype(self, includ assert name_origin + suffix in res output_fields = None self.compare_collections(name_origin, name_origin + suffix, output_fields=output_fields) + # check text match and full text search in restored collection + word_freq = analyze_documents(texts) + token = word_freq.most_common(1)[0][0] + c = Collection(name=name_origin + suffix) + create_index_for_vector_fields(c) + if enable_text_match: + res = c.query( + expr=f"TextMatch(text, '{token}')", + output_fields=["text"], + limit=1 + ) + assert len(res) == 1 + for r in res: + assert token in r["text"] + if enable_full_text_search: + search_data = [fake_en.text()+f" {token} "] + res = c.search( + data=search_data, + anns_field="bm25_sparse_vector", + output_fields=["text"], + limit=1 + ) + assert len(res) == 1 + for r in res: + assert len(r) == 1 res = client.delete_backup(back_up_name) res = client.list_backup() if "data" in res: diff --git a/tests/utils/util_common.py b/tests/utils/util_common.py index e2dcccb1..6752d38d 100644 --- a/tests/utils/util_common.py +++ b/tests/utils/util_common.py @@ -1,7 +1,68 @@ from yaml import full_load import json +from collections import Counter +from bm25s.tokenization import Tokenizer +import jieba +import re from utils.util_log import test_log as log + +def custom_tokenizer(language="en"): + def remove_punctuation(text): + text = text.strip() + text = text.replace("\n", " ") + return re.sub(r'[^\w\s]', ' ', text) + + # Tokenize the corpus + def jieba_split(text): + text_without_punctuation = remove_punctuation(text) + return jieba.lcut(text_without_punctuation) + + def blank_space_split(text): + text_without_punctuation = remove_punctuation(text) + return text_without_punctuation.split() + + stopwords = [" "] + stemmer = None + if language in ["zh", "cn", "chinese"]: + splitter = jieba_split + tokenizer = Tokenizer( + stemmer=stemmer, splitter=splitter, stopwords=stopwords + ) + else: + splitter = blank_space_split + tokenizer = Tokenizer( + stemmer=stemmer, splitter= splitter, stopwords=stopwords + ) + return tokenizer + + +def analyze_documents(texts, language="en"): + + tokenizer = custom_tokenizer(language) + new_texts = [] + for text in texts: + if isinstance(text, str): + new_texts.append(text) + # Tokenize the corpus + tokenized = tokenizer.tokenize(new_texts, return_as="tuple") + # log.info(f"Tokenized: {tokenized}") + # Create a frequency counter + freq = Counter() + + # Count the frequency of each token + for doc_ids in tokenized.ids: + freq.update(doc_ids) + # Create a reverse vocabulary mapping + id_to_word = {id: word for word, id in tokenized.vocab.items()} + + # Convert token ids back to words + word_freq = Counter({id_to_word[token_id]: count for token_id, count in freq.items()}) + log.debug(f"word freq {word_freq.most_common(10)}") + + return word_freq + + def gen_experiment_config(yaml): """load the yaml file of chaos experiment""" with open(yaml) as f: diff --git a/tests/utils/util_pymilvus.py b/tests/utils/util_pymilvus.py index 34e47546..a718d485 100644 --- a/tests/utils/util_pymilvus.py +++ b/tests/utils/util_pymilvus.py @@ -8,7 +8,7 @@ import numpy as np import requests from sklearn import preprocessing -from pymilvus import Milvus, DataType +from pymilvus import Milvus, DataType, FunctionType from utils.util_log import test_log as log from utils.util_k8s import init_k8s_client_config @@ -56,6 +56,89 @@ ] +DEFAULT_FLOAT_INDEX_PARAM = {"index_type": "HNSW", "metric_type": "L2", "params": {"M": 48, "efConstruction": 500}} +DEFAULT_FLOAT_SEARCH_PARAM = {"metric_type": "L2", "params": {"ef": 64}} +DEFAULT_BINARY_INDEX_PARAM = {"index_type": "BIN_IVF_FLAT", "metric_type": "JACCARD", "params": {"M": 48}} +DEFAULT_BINARY_SEARCH_PARAM = {"metric_type": "JACCARD", "params": {"nprobe": 10}} +DEFAULT_SPARSE_INDEX_PARAM = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP", "params": {}} +DEFAULT_SPARSE_SEARCH_PARAM = {"metric_type": "IP", "params": {}} +DEFAULT_BM25_INDEX_PARAM = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "BM25", "params": {"bm25_k1": 1.5, "bm25_b": 0.75}} +DEFAULT_BM25_SEARCH_PARAM = {"metric_type": "BM25", "params": {}} + + +def get_float_vec_field_name_list(schema): + vec_fields = [] + fields = schema.fields + for field in fields: + if field.dtype in [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR]: + vec_fields.append(field.name) + return vec_fields + +def get_binary_vec_field_name_list(schema): + vec_fields = [] + fields = schema.fields + for field in fields: + if field.dtype in [DataType.BINARY_VECTOR]: + vec_fields.append(field.name) + return vec_fields + +def get_bm25_vec_field_name_list(schema=None): + if not hasattr(schema, "functions"): + return [] + functions = schema.functions + bm25_func = [func for func in functions if func.type == FunctionType.BM25] + bm25_outputs = [] + for func in bm25_func: + bm25_outputs.extend(func.output_field_names) + bm25_outputs = list(set(bm25_outputs)) + + return bm25_outputs + +def get_sparse_vec_field_name_list(schema): + # SPARSE_FLOAT_VECTOR but not in BM25 + vec_fields = [] + bm25_fields = get_bm25_vec_field_name_list(schema) + fields = schema.fields + for field in fields: + if field.dtype in [DataType.SPARSE_FLOAT_VECTOR]: + vec_fields.append(field.name) + return list(set(vec_fields) - set(bm25_fields)) + + +def create_index_for_vector_fields(collection): + schema = collection.schema + float_vector_fields = get_float_vec_field_name_list(schema) + binary_vector_fields = get_binary_vec_field_name_list(schema) + sparse_vector_fields = get_sparse_vec_field_name_list(schema) + bm25_vector_fields = get_bm25_vec_field_name_list(schema) + indexes = [index.to_dict() for index in collection.indexes] + indexed_fields = [index['field'] for index in indexes] + for field_name in float_vector_fields: + if field_name in indexed_fields: + continue + collection.create_index(field_name, DEFAULT_FLOAT_INDEX_PARAM) + for field_name in binary_vector_fields: + if field_name in indexed_fields: + continue + collection.create_index(field_name, DEFAULT_BINARY_INDEX_PARAM) + for field_name in sparse_vector_fields: + if field_name in indexed_fields: + continue + collection.create_index(field_name, DEFAULT_SPARSE_INDEX_PARAM) + for field_name in bm25_vector_fields: + if field_name in indexed_fields: + continue + collection.create_index(field_name, DEFAULT_BM25_INDEX_PARAM) + + + + + + + + + + def create_target_index(index, field_name): index["field_name"] = field_name