Skip to content

Commit

Permalink
Fix bug: failed to insert binary vector (#1491)
Browse files Browse the repository at this point in the history
Signed-off-by: zhenshan.cao <[email protected]>
  • Loading branch information
czs007 authored May 30, 2023
1 parent 8926c67 commit bdfcd7b
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 5 deletions.
69 changes: 69 additions & 0 deletions examples/binary_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import time
import random
import numpy as np
from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, DataType,
Collection,
)


bin_index_types = ["BIN_FLAT", "BIN_IVF_FLAT"]

default_bin_index_params = [{"nlist": 128}, {"nlist": 128}]

def gen_binary_vectors(num, dim):
raw_vectors = []
binary_vectors = []
for _ in range(num):
raw_vector = [random.randint(0, 1) for _ in range(dim)]
raw_vectors.append(raw_vector)
# packs a binary-valued array into bits in a unit8 array, and bytes array_of_ints
binary_vectors.append(bytes(np.packbits(raw_vector, axis=-1).tolist()))
return raw_vectors, binary_vectors


def binary_vector_search():
connections.connect()
int64_field = FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True, auto_id=True)
dim = 128
nb = 3000
vector_field_name = "binary_vector"
binary_vector = FieldSchema(name=vector_field_name, dtype=DataType.BINARY_VECTOR, dim=dim)
schema = CollectionSchema(fields=[int64_field, binary_vector], enable_dynamic_field=True)

has = utility.has_collection("hello_milvus")
if has:
hello_milvus = Collection("hello_milvus_bin")
hello_milvus.drop()
else:
hello_milvus = Collection("hello_milvus_bin", schema)

_, vectors = gen_binary_vectors(nb, dim)
rows = [
{vector_field_name: vectors[0]},
{vector_field_name: vectors[1]},
{vector_field_name: vectors[2]},
{vector_field_name: vectors[3]},
{vector_field_name: vectors[4]},
{vector_field_name: vectors[5]},
]

hello_milvus.insert(rows)
hello_milvus.flush()
for i, index_type in enumerate(bin_index_types):
index_params = default_bin_index_params[i]
hello_milvus.create_index(vector_field_name,
index_params={"index_type": index_type, "params": index_params, "metric_type": "HAMMING"})
hello_milvus.load()
print("index_type = ", index_type)
res = hello_milvus.search(vectors[:1], vector_field_name, {"metric_type": "HAMMING"}, limit=1)
print("res = ", res)
hello_milvus.release()
hello_milvus.drop_index()
hello_milvus.drop()


if __name__ == "__main__":
binary_vector_search()
10 changes: 5 additions & 5 deletions pymilvus/client/entity_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def pack_field_value_to_field_data(field_value, field_data, field_info):
field_data.vectors.float_vector.data.extend(field_value)
elif field_type in (DataType.BINARY_VECTOR,):
field_data.vectors.dim = len(field_value) * 8
field_data.vectors.binary_vector.data.append(b''.join(field_value))
field_data.vectors.binary_vector += bytes(field_value)
elif field_type in (DataType.VARCHAR,):
field_data.scalars.string_data.data.append(
convert_to_str_array(field_value, field_info, True))
Expand Down Expand Up @@ -204,10 +204,10 @@ def extract_row_data_from_fields_data(fields_data, index, dynamic_output_fields=
start_pos:end_pos]]
elif field_data.type == DataType.BINARY_VECTOR:
dim = field_data.vectors.dim
if len(field_data.vectors.binary_vector.data) >= index * (dim / 8):
start_pos = index * (dim / 8)
end_pos = (index + 1) * (dim / 8)
if len(field_data.vectors.binary_vector) >= index * (dim // 8):
start_pos = index * (dim // 8)
end_pos = (index + 1) * (dim // 8)
entity_row_data[field_data.field_name] = [
field_data.vectors.binary_vector.data[start_pos:end_pos]]
field_data.vectors.binary_vector[start_pos:end_pos]]

return entity_row_data

0 comments on commit bdfcd7b

Please sign in to comment.