Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: add testcases for bm25 function #447

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ allure-pytest==2.7.0
pytest-print==0.2.1
pytest-level==0.1.1
pytest-xdist==2.5.0
pymilvus==2.4.5rc11
pymilvus[bulk_writer]==2.4.5rc11
pymilvus==2.5.0
pymilvus[bulk_writer]==2.5.0
pytest-rerunfailures==9.1.1
git+https://github.com/Projectplace/pytest-tags
ndg-httpsclient
Expand Down
111 changes: 70 additions & 41 deletions tests/testcases/test_restore_backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import jax.numpy as jnp
import random
from collections import defaultdict
from pymilvus import db, list_collections, Collection, DataType
from pymilvus import db, list_collections, Collection, DataType, Function, FunctionType
from base.client_base import TestcaseBase
from common import common_func as cf
from common import common_type as ct
from common.common_type import CaseLabel
from utils.util_log import test_log as log
from utils.util_common import analyze_documents
from utils.util_pymilvus import create_index_for_vector_fields
from api.milvus_backup import MilvusBackupClient
from faker import Faker
fake_en = Faker("en_US")
Expand Down Expand Up @@ -193,7 +195,7 @@ def test_milvus_restore_back_with_db_support(self):
@pytest.mark.parametrize("include_dynamic", [True])
@pytest.mark.parametrize("include_json", [True])
@pytest.mark.tags(CaseLabel.L0)
def test_milvus_restore_back_with_json_dynamic_schema_partition_key(self, include_json, include_dynamic, include_partition_key):
def test_milvus_restore_back_with_new_dynamic_schema_and_partition_key(self, include_json, include_dynamic, include_partition_key):
self._connect()
name_origin = cf.gen_unique_str(prefix)
back_up_name = cf.gen_unique_str(backup_prefix)
Expand Down Expand Up @@ -651,20 +653,24 @@ def test_milvus_restore_back_with_f16_bf16_datatype(self, include_dynamic, inclu
@pytest.mark.parametrize("include_partition_key", [True])
@pytest.mark.parametrize("include_dynamic", [True])
@pytest.mark.parametrize("enable_text_match", [True])
@pytest.mark.parametrize("enable_full_text_search", [True])
@pytest.mark.tags(CaseLabel.MASTER)
def test_milvus_restore_back_with_sparse_vector_text_match_datatype(self, include_dynamic, include_partition_key, enable_text_match):
def test_milvus_restore_back_with_sparse_vector_text_match_datatype(self, include_dynamic, include_partition_key, enable_text_match, enable_full_text_search):
self._connect()
name_origin = cf.gen_unique_str(prefix)
back_up_name = cf.gen_unique_str(backup_prefix)
fields = [cf.gen_int64_field(name="int64", is_primary=True),
cf.gen_int64_field(name="key"),
cf.gen_string_field(name="text", enable_match=enable_text_match),
cf.gen_json_field(name="json"),
cf.gen_array_field(name="var_array", element_type=DataType.VARCHAR),
cf.gen_array_field(name="int_array", element_type=DataType.INT64),
cf.gen_float_vec_field(name="float_vector", dim=128),
cf.gen_sparse_vec_field(name="sparse_vector"),
]
cf.gen_int64_field(name="key"),
cf.gen_string_field(name="text", enable_match=enable_text_match, enable_analyzer=True),
cf.gen_json_field(name="json"),
cf.gen_array_field(name="var_array", element_type=DataType.VARCHAR),
cf.gen_array_field(name="int_array", element_type=DataType.INT64),
cf.gen_float_vec_field(name="float_vector", dim=128),
cf.gen_sparse_vec_field(name="sparse_vector"),
# cf.gen_sparse_vec_field(name="bm25_sparse_vector"),
]
if enable_full_text_search:
fields.append(cf.gen_sparse_vec_field(name="bm25_sparse_vector"))
if include_partition_key:
partition_key = "key"
default_schema = cf.gen_collection_schema(fields,
Expand All @@ -673,40 +679,38 @@ def test_milvus_restore_back_with_sparse_vector_text_match_datatype(self, includ
else:
default_schema = cf.gen_collection_schema(fields,
enable_dynamic_field=include_dynamic)

if enable_full_text_search:
bm25_function = Function(
name="text_bm25_emb",
function_type=FunctionType.BM25,
input_field_names=["text"],
output_field_names=["bm25_sparse_vector"],
params={},
)
default_schema.add_function(bm25_function)
collection_w = self.init_collection_wrap(name=name_origin, schema=default_schema, active_trace=True)
nb = 3000
rng = np.random.default_rng()
data = [
[i for i in range(nb)],
[i % 3 for i in range(nb)],
[fake_en.text() for i in range(nb)],
[{f"key_{str(i)}": i} for i in range(nb)],
[[str(x) for x in range(10)] for i in range(nb)],
[[int(x) for x in range(10)] for i in range(nb)],
[[np.float32(i) for i in range(128)] for _ in range(nb)],
[{
d: rng.random() for d in random.sample(range(1000), random.randint(20, 30))
} for _ in range(nb)],
]

data = []
for i in range(nb):
tmp = {
"int64": i,
"key": i % 3,
"text": fake_en.text(),
"json": {f"key_{str(i)}": i},
"var_array": [str(x) for x in range(10)],
"int_array": [int(x) for x in range(10)],
"float_vector": [np.float32(i) for i in range(128)],
"sparse_vector": {
d: rng.random() for d in random.sample(range(1000), random.randint(20, 30))
},
}
if include_dynamic:
tmp[f"dynamic_{str(i)}"] = i
data.append(tmp)
texts = [d["text"] for d in data]
collection_w.insert(data=data)
if include_dynamic:
data = [
{
"int64": i,
"key": i % 3,
"text": fake_en.text(),
"json": {f"key_{str(i)}": i},
"var_array": [str(x) for x in range(10)],
"int_array": [int(x) for x in range(10)],
"float_vector": [np.float32(i) for i in range(128)],
"sparse_vector": {
d: rng.random() for d in random.sample(range(1000), random.randint(20, 30))
},
f"dynamic_{str(i)}": i
} for i in range(nb, nb*2)
]
collection_w.insert(data=data)
res = client.create_backup({"async": False, "backup_name": back_up_name, "collection_names": [name_origin]})
log.info(f"create_backup {res}")
res = client.list_backup()
Expand All @@ -727,6 +731,31 @@ def test_milvus_restore_back_with_sparse_vector_text_match_datatype(self, includ
assert name_origin + suffix in res
output_fields = None
self.compare_collections(name_origin, name_origin + suffix, output_fields=output_fields)
# check text match and full text search in restored collection
word_freq = analyze_documents(texts)
token = word_freq.most_common(1)[0][0]
c = Collection(name=name_origin + suffix)
create_index_for_vector_fields(c)
if enable_text_match:
res = c.query(
expr=f"TextMatch(text, '{token}')",
output_fields=["text"],
limit=1
)
assert len(res) == 1
for r in res:
assert token in r["text"]
if enable_full_text_search:
search_data = [fake_en.text()+f" {token} "]
res = c.search(
data=search_data,
anns_field="bm25_sparse_vector",
output_fields=["text"],
limit=1
)
assert len(res) == 1
for r in res:
assert len(r) == 1
res = client.delete_backup(back_up_name)
res = client.list_backup()
if "data" in res:
Expand Down
61 changes: 61 additions & 0 deletions tests/utils/util_common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,68 @@
from yaml import full_load
import json
from collections import Counter
from bm25s.tokenization import Tokenizer
import jieba
import re
from utils.util_log import test_log as log


def custom_tokenizer(language="en"):
def remove_punctuation(text):
text = text.strip()
text = text.replace("\n", " ")
return re.sub(r'[^\w\s]', ' ', text)

# Tokenize the corpus
def jieba_split(text):
text_without_punctuation = remove_punctuation(text)
return jieba.lcut(text_without_punctuation)

def blank_space_split(text):
text_without_punctuation = remove_punctuation(text)
return text_without_punctuation.split()

stopwords = [" "]
stemmer = None
if language in ["zh", "cn", "chinese"]:
splitter = jieba_split
tokenizer = Tokenizer(
stemmer=stemmer, splitter=splitter, stopwords=stopwords
)
else:
splitter = blank_space_split
tokenizer = Tokenizer(
stemmer=stemmer, splitter= splitter, stopwords=stopwords
)
return tokenizer


def analyze_documents(texts, language="en"):

tokenizer = custom_tokenizer(language)
new_texts = []
for text in texts:
if isinstance(text, str):
new_texts.append(text)
# Tokenize the corpus
tokenized = tokenizer.tokenize(new_texts, return_as="tuple")
# log.info(f"Tokenized: {tokenized}")
# Create a frequency counter
freq = Counter()

# Count the frequency of each token
for doc_ids in tokenized.ids:
freq.update(doc_ids)
# Create a reverse vocabulary mapping
id_to_word = {id: word for word, id in tokenized.vocab.items()}

# Convert token ids back to words
word_freq = Counter({id_to_word[token_id]: count for token_id, count in freq.items()})
log.debug(f"word freq {word_freq.most_common(10)}")

return word_freq


def gen_experiment_config(yaml):
"""load the yaml file of chaos experiment"""
with open(yaml) as f:
Expand Down
85 changes: 84 additions & 1 deletion tests/utils/util_pymilvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import requests
from sklearn import preprocessing
from pymilvus import Milvus, DataType
from pymilvus import Milvus, DataType, FunctionType
from utils.util_log import test_log as log
from utils.util_k8s import init_k8s_client_config

Expand Down Expand Up @@ -56,6 +56,89 @@
]


DEFAULT_FLOAT_INDEX_PARAM = {"index_type": "HNSW", "metric_type": "L2", "params": {"M": 48, "efConstruction": 500}}
DEFAULT_FLOAT_SEARCH_PARAM = {"metric_type": "L2", "params": {"ef": 64}}
DEFAULT_BINARY_INDEX_PARAM = {"index_type": "BIN_IVF_FLAT", "metric_type": "JACCARD", "params": {"M": 48}}
DEFAULT_BINARY_SEARCH_PARAM = {"metric_type": "JACCARD", "params": {"nprobe": 10}}
DEFAULT_SPARSE_INDEX_PARAM = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP", "params": {}}
DEFAULT_SPARSE_SEARCH_PARAM = {"metric_type": "IP", "params": {}}
DEFAULT_BM25_INDEX_PARAM = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "BM25", "params": {"bm25_k1": 1.5, "bm25_b": 0.75}}
DEFAULT_BM25_SEARCH_PARAM = {"metric_type": "BM25", "params": {}}


def get_float_vec_field_name_list(schema):
vec_fields = []
fields = schema.fields
for field in fields:
if field.dtype in [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR]:
vec_fields.append(field.name)
return vec_fields

def get_binary_vec_field_name_list(schema):
vec_fields = []
fields = schema.fields
for field in fields:
if field.dtype in [DataType.BINARY_VECTOR]:
vec_fields.append(field.name)
return vec_fields

def get_bm25_vec_field_name_list(schema=None):
if not hasattr(schema, "functions"):
return []
functions = schema.functions
bm25_func = [func for func in functions if func.type == FunctionType.BM25]
bm25_outputs = []
for func in bm25_func:
bm25_outputs.extend(func.output_field_names)
bm25_outputs = list(set(bm25_outputs))

return bm25_outputs

def get_sparse_vec_field_name_list(schema):
# SPARSE_FLOAT_VECTOR but not in BM25
vec_fields = []
bm25_fields = get_bm25_vec_field_name_list(schema)
fields = schema.fields
for field in fields:
if field.dtype in [DataType.SPARSE_FLOAT_VECTOR]:
vec_fields.append(field.name)
return list(set(vec_fields) - set(bm25_fields))


def create_index_for_vector_fields(collection):
schema = collection.schema
float_vector_fields = get_float_vec_field_name_list(schema)
binary_vector_fields = get_binary_vec_field_name_list(schema)
sparse_vector_fields = get_sparse_vec_field_name_list(schema)
bm25_vector_fields = get_bm25_vec_field_name_list(schema)
indexes = [index.to_dict() for index in collection.indexes]
indexed_fields = [index['field'] for index in indexes]
for field_name in float_vector_fields:
if field_name in indexed_fields:
continue
collection.create_index(field_name, DEFAULT_FLOAT_INDEX_PARAM)
for field_name in binary_vector_fields:
if field_name in indexed_fields:
continue
collection.create_index(field_name, DEFAULT_BINARY_INDEX_PARAM)
for field_name in sparse_vector_fields:
if field_name in indexed_fields:
continue
collection.create_index(field_name, DEFAULT_SPARSE_INDEX_PARAM)
for field_name in bm25_vector_fields:
if field_name in indexed_fields:
continue
collection.create_index(field_name, DEFAULT_BM25_INDEX_PARAM)










def create_target_index(index, field_name):
index["field_name"] = field_name

Expand Down
Loading