Skip to content

Commit

Permalink
test: add testcases for bm25 function
Browse files Browse the repository at this point in the history
Signed-off-by: zhuwenxing <[email protected]>
  • Loading branch information
zhuwenxing committed Dec 20, 2024
1 parent 6b29a3a commit ac5f550
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 44 deletions.
4 changes: 2 additions & 2 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ allure-pytest==2.7.0
pytest-print==0.2.1
pytest-level==0.1.1
pytest-xdist==2.5.0
pymilvus==2.4.5rc11
pymilvus[bulk_writer]==2.4.5rc11
pymilvus==2.5.0rc104
pymilvus[bulk_writer]==2.5.0rc104
pytest-rerunfailures==9.1.1
git+https://github.com/Projectplace/pytest-tags
ndg-httpsclient
Expand Down
111 changes: 70 additions & 41 deletions tests/testcases/test_restore_backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import jax.numpy as jnp
import random
from collections import defaultdict
from pymilvus import db, list_collections, Collection, DataType
from pymilvus import db, list_collections, Collection, DataType, Function, FunctionType
from base.client_base import TestcaseBase
from common import common_func as cf
from common import common_type as ct
from common.common_type import CaseLabel
from utils.util_log import test_log as log
from utils.util_common import analyze_documents
from utils.util_pymilvus import create_index_for_vector_fields
from api.milvus_backup import MilvusBackupClient
from faker import Faker
fake_en = Faker("en_US")
Expand Down Expand Up @@ -193,7 +195,7 @@ def test_milvus_restore_back_with_db_support(self):
@pytest.mark.parametrize("include_dynamic", [True])
@pytest.mark.parametrize("include_json", [True])
@pytest.mark.tags(CaseLabel.L0)
def test_milvus_restore_back_with_json_dynamic_schema_partition_key(self, include_json, include_dynamic, include_partition_key):
def test_milvus_restore_back_with_new_dynamic_schema_and_partition_key(self, include_json, include_dynamic, include_partition_key):
self._connect()
name_origin = cf.gen_unique_str(prefix)
back_up_name = cf.gen_unique_str(backup_prefix)
Expand Down Expand Up @@ -651,20 +653,24 @@ def test_milvus_restore_back_with_f16_bf16_datatype(self, include_dynamic, inclu
@pytest.mark.parametrize("include_partition_key", [True])
@pytest.mark.parametrize("include_dynamic", [True])
@pytest.mark.parametrize("enable_text_match", [True])
@pytest.mark.parametrize("enable_full_text_search", [True])
@pytest.mark.tags(CaseLabel.MASTER)
def test_milvus_restore_back_with_sparse_vector_text_match_datatype(self, include_dynamic, include_partition_key, enable_text_match):
def test_milvus_restore_back_with_sparse_vector_text_match_datatype(self, include_dynamic, include_partition_key, enable_text_match, enable_full_text_search):
self._connect()
name_origin = cf.gen_unique_str(prefix)
back_up_name = cf.gen_unique_str(backup_prefix)
fields = [cf.gen_int64_field(name="int64", is_primary=True),
cf.gen_int64_field(name="key"),
cf.gen_string_field(name="text", enable_match=enable_text_match),
cf.gen_json_field(name="json"),
cf.gen_array_field(name="var_array", element_type=DataType.VARCHAR),
cf.gen_array_field(name="int_array", element_type=DataType.INT64),
cf.gen_float_vec_field(name="float_vector", dim=128),
cf.gen_sparse_vec_field(name="sparse_vector"),
]
cf.gen_int64_field(name="key"),
cf.gen_string_field(name="text", enable_match=enable_text_match, enable_tokenizer=True),
cf.gen_json_field(name="json"),
cf.gen_array_field(name="var_array", element_type=DataType.VARCHAR),
cf.gen_array_field(name="int_array", element_type=DataType.INT64),
cf.gen_float_vec_field(name="float_vector", dim=128),
cf.gen_sparse_vec_field(name="sparse_vector"),
# cf.gen_sparse_vec_field(name="bm25_sparse_vector"),
]
if enable_full_text_search:
fields.append(cf.gen_sparse_vec_field(name="bm25_sparse_vector"))
if include_partition_key:
partition_key = "key"
default_schema = cf.gen_collection_schema(fields,
Expand All @@ -673,40 +679,38 @@ def test_milvus_restore_back_with_sparse_vector_text_match_datatype(self, includ
else:
default_schema = cf.gen_collection_schema(fields,
enable_dynamic_field=include_dynamic)

if enable_full_text_search:
bm25_function = Function(
name="text_bm25_emb",
function_type=FunctionType.BM25,
input_field_names=["text"],
output_field_names=["bm25_sparse_vector"],
params={},
)
default_schema.add_function(bm25_function)
collection_w = self.init_collection_wrap(name=name_origin, schema=default_schema, active_trace=True)
nb = 3000
rng = np.random.default_rng()
data = [
[i for i in range(nb)],
[i % 3 for i in range(nb)],
[fake_en.text() for i in range(nb)],
[{f"key_{str(i)}": i} for i in range(nb)],
[[str(x) for x in range(10)] for i in range(nb)],
[[int(x) for x in range(10)] for i in range(nb)],
[[np.float32(i) for i in range(128)] for _ in range(nb)],
[{
d: rng.random() for d in random.sample(range(1000), random.randint(20, 30))
} for _ in range(nb)],
]

data = []
for i in range(nb):
tmp = {
"int64": i,
"key": i % 3,
"text": fake_en.text(),
"json": {f"key_{str(i)}": i},
"var_array": [str(x) for x in range(10)],
"int_array": [int(x) for x in range(10)],
"float_vector": [np.float32(i) for i in range(128)],
"sparse_vector": {
d: rng.random() for d in random.sample(range(1000), random.randint(20, 30))
},
}
if include_dynamic:
tmp[f"dynamic_{str(i)}"] = i
data.append(tmp)
texts = [d["text"] for d in data]
collection_w.insert(data=data)
if include_dynamic:
data = [
{
"int64": i,
"key": i % 3,
"text": fake_en.text(),
"json": {f"key_{str(i)}": i},
"var_array": [str(x) for x in range(10)],
"int_array": [int(x) for x in range(10)],
"float_vector": [np.float32(i) for i in range(128)],
"sparse_vector": {
d: rng.random() for d in random.sample(range(1000), random.randint(20, 30))
},
f"dynamic_{str(i)}": i
} for i in range(nb, nb*2)
]
collection_w.insert(data=data)
res = client.create_backup({"async": False, "backup_name": back_up_name, "collection_names": [name_origin]})
log.info(f"create_backup {res}")
res = client.list_backup()
Expand All @@ -727,6 +731,31 @@ def test_milvus_restore_back_with_sparse_vector_text_match_datatype(self, includ
assert name_origin + suffix in res
output_fields = None
self.compare_collections(name_origin, name_origin + suffix, output_fields=output_fields)
# check text match and full text search in restored collection
word_freq = analyze_documents(texts)
token = word_freq.most_common(1)[0][0]
c = Collection(name=name_origin + suffix)
create_index_for_vector_fields(c)
if enable_text_match:
res = c.query(
expr=f"TextMatch(text, '{token}')",
output_fields=["text"],
limit=1
)
assert len(res) == 1
for r in res:
assert token in r["text"]
if enable_full_text_search:
search_data = [fake_en.text()+f" {token} "]
res = c.search(
data=search_data,
anns_field="bm25_sparse_vector",
output_fields=["text"],
limit=1
)
assert len(res) == 1
for r in res:
assert len(r) == 1
res = client.delete_backup(back_up_name)
res = client.list_backup()
if "data" in res:
Expand Down
61 changes: 61 additions & 0 deletions tests/utils/util_common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,68 @@
from yaml import full_load
import json
from collections import Counter
from bm25s.tokenization import Tokenizer
import jieba
import re
from utils.util_log import test_log as log


def custom_tokenizer(language="en"):
def remove_punctuation(text):
text = text.strip()
text = text.replace("\n", " ")
return re.sub(r'[^\w\s]', ' ', text)

# Tokenize the corpus
def jieba_split(text):
text_without_punctuation = remove_punctuation(text)
return jieba.lcut(text_without_punctuation)

def blank_space_split(text):
text_without_punctuation = remove_punctuation(text)
return text_without_punctuation.split()

stopwords = [" "]
stemmer = None
if language in ["zh", "cn", "chinese"]:
splitter = jieba_split
tokenizer = Tokenizer(
stemmer=stemmer, splitter=splitter, stopwords=stopwords
)
else:
splitter = blank_space_split
tokenizer = Tokenizer(
stemmer=stemmer, splitter= splitter, stopwords=stopwords
)
return tokenizer


def analyze_documents(texts, language="en"):

tokenizer = custom_tokenizer(language)
new_texts = []
for text in texts:
if isinstance(text, str):
new_texts.append(text)
# Tokenize the corpus
tokenized = tokenizer.tokenize(new_texts, return_as="tuple")
# log.info(f"Tokenized: {tokenized}")
# Create a frequency counter
freq = Counter()

# Count the frequency of each token
for doc_ids in tokenized.ids:
freq.update(doc_ids)
# Create a reverse vocabulary mapping
id_to_word = {id: word for word, id in tokenized.vocab.items()}

# Convert token ids back to words
word_freq = Counter({id_to_word[token_id]: count for token_id, count in freq.items()})
log.debug(f"word freq {word_freq.most_common(10)}")

return word_freq


def gen_experiment_config(yaml):
"""load the yaml file of chaos experiment"""
with open(yaml) as f:
Expand Down
85 changes: 84 additions & 1 deletion tests/utils/util_pymilvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import requests
from sklearn import preprocessing
from pymilvus import Milvus, DataType
from pymilvus import Milvus, DataType, FunctionType
from utils.util_log import test_log as log
from utils.util_k8s import init_k8s_client_config

Expand Down Expand Up @@ -56,6 +56,89 @@
]


DEFAULT_FLOAT_INDEX_PARAM = {"index_type": "HNSW", "metric_type": "L2", "params": {"M": 48, "efConstruction": 500}}
DEFAULT_FLOAT_SEARCH_PARAM = {"metric_type": "L2", "params": {"ef": 64}}
DEFAULT_BINARY_INDEX_PARAM = {"index_type": "BIN_IVF_FLAT", "metric_type": "JACCARD", "params": {"M": 48}}
DEFAULT_BINARY_SEARCH_PARAM = {"metric_type": "JACCARD", "params": {"nprobe": 10}}
DEFAULT_SPARSE_INDEX_PARAM = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP", "params": {}}
DEFAULT_SPARSE_SEARCH_PARAM = {"metric_type": "IP", "params": {}}
DEFAULT_BM25_INDEX_PARAM = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "BM25", "params": {"bm25_k1": 1.5, "bm25_b": 0.75}}
DEFAULT_BM25_SEARCH_PARAM = {"metric_type": "BM25", "params": {}}


def get_float_vec_field_name_list(schema):
vec_fields = []
fields = schema.fields
for field in fields:
if field.dtype in [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR]:
vec_fields.append(field.name)
return vec_fields

def get_binary_vec_field_name_list(schema):
vec_fields = []
fields = schema.fields
for field in fields:
if field.dtype in [DataType.BINARY_VECTOR]:
vec_fields.append(field.name)
return vec_fields

def get_bm25_vec_field_name_list(schema=None):
if not hasattr(schema, "functions"):
return []
functions = schema.functions
bm25_func = [func for func in functions if func.type == FunctionType.BM25]
bm25_outputs = []
for func in bm25_func:
bm25_outputs.extend(func.output_field_names)
bm25_outputs = list(set(bm25_outputs))

return bm25_outputs

def get_sparse_vec_field_name_list(schema):
# SPARSE_FLOAT_VECTOR but not in BM25
vec_fields = []
bm25_fields = get_bm25_vec_field_name_list(schema)
fields = schema.fields
for field in fields:
if field.dtype in [DataType.SPARSE_FLOAT_VECTOR]:
vec_fields.append(field.name)
return list(set(vec_fields) - set(bm25_fields))


def create_index_for_vector_fields(collection):
schema = collection.schema
float_vector_fields = get_float_vec_field_name_list(schema)
binary_vector_fields = get_binary_vec_field_name_list(schema)
sparse_vector_fields = get_sparse_vec_field_name_list(schema)
bm25_vector_fields = get_bm25_vec_field_name_list(schema)
indexes = [index.to_dict() for index in collection.indexes]
indexed_fields = [index['field'] for index in indexes]
for field_name in float_vector_fields:
if field_name in indexed_fields:
continue
collection.create_index(field_name, DEFAULT_FLOAT_INDEX_PARAM)
for field_name in binary_vector_fields:
if field_name in indexed_fields:
continue
collection.create_index(field_name, DEFAULT_BINARY_INDEX_PARAM)
for field_name in sparse_vector_fields:
if field_name in indexed_fields:
continue
collection.create_index(field_name, DEFAULT_SPARSE_INDEX_PARAM)
for field_name in bm25_vector_fields:
if field_name in indexed_fields:
continue
collection.create_index(field_name, DEFAULT_BM25_INDEX_PARAM)










def create_target_index(index, field_name):
index["field_name"] = field_name

Expand Down

0 comments on commit ac5f550

Please sign in to comment.