Skip to content

Commit

Permalink
fix: use correct primary field name in Hits (#2559)
Browse files Browse the repository at this point in the history
issue: #2558

Signed-off-by: zhenshan.cao <[email protected]>
  • Loading branch information
czs007 authored Jan 13, 2025
1 parent 3b236f0 commit c07e656
Show file tree
Hide file tree
Showing 10 changed files with 539 additions and 448 deletions.
62 changes: 62 additions & 0 deletions examples/customize_schema_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import time
import numpy as np
from pymilvus import (
MilvusClient,
DataType
)

# Example: a collection whose primary key is NOT called "id".
#
# The primary key field is named "uid", while "id" is declared as an ordinary
# VARCHAR field. Searching with output_fields=["id"] and printing the hits
# lets you see both values side by side in the output.

fmt = "\n=== {:30} ===\n"
dim = 8
collection_name = "hello_milvus"
milvus_client = MilvusClient("http://localhost:19530")

# Start from a clean slate: drop any collection left over from an earlier run.
has_collection = milvus_client.has_collection(collection_name, timeout=5)
if has_collection:
    milvus_client.drop_collection(collection_name)

# Schema: INT64 primary key "uid", a float vector, and two VARCHAR fields,
# one of which is intentionally named "id".
schema = milvus_client.create_schema(enable_dynamic_field=True)
schema.add_field("uid", DataType.INT64, is_primary=True)
schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim)
schema.add_field("title", DataType.VARCHAR, max_length=64)
schema.add_field("id", DataType.VARCHAR, max_length=64)

index_params = milvus_client.prepare_index_params()
index_params.add_index(field_name="embeddings", metric_type="L2")
milvus_client.create_collection(
    collection_name,
    schema=schema,
    index_params=index_params,
    consistency_level="Strong",
)

print(fmt.format(" all collections "))
print(milvus_client.list_collections())

print(fmt.format(f"schema of collection {collection_name}"))
print(milvus_client.describe_collection(collection_name))

# Fixed seed so the generated embeddings are reproducible between runs.
# Each row also carries one key ("a" .. "f") that is not declared in the
# schema; the schema above was created with enable_dynamic_field=True.
rng = np.random.default_rng(seed=19530)
rows = [
    {"uid": 1, "embeddings": rng.random((1, dim))[0], "a": 100, "title": "t1", "id": "u1"},
    {"uid": 2, "embeddings": rng.random((1, dim))[0], "b": 200, "title": "t2", "id": "u2"},
    {"uid": 3, "embeddings": rng.random((1, dim))[0], "c": 300, "title": "t3", "id": "u3"},
    {"uid": 4, "embeddings": rng.random((1, dim))[0], "d": 400, "title": "t4", "id": "u4"},
    {"uid": 5, "embeddings": rng.random((1, dim))[0], "e": 500, "title": "t5", "id": "u5"},
    {"uid": 6, "embeddings": rng.random((1, dim))[0], "f": 600, "title": "t6", "id": "u6"},
]

print(fmt.format("Start inserting entities"))
insert_result = milvus_client.insert(collection_name, rows)
print(fmt.format("Inserting entities done"))
print(insert_result)

print(fmt.format("Start load collection "))
milvus_client.load_collection(collection_name)

# Re-create the generator with the same seed: its first draw repeats the
# first draw made above, so the query vector matches the embedding of uid 1.
rng = np.random.default_rng(seed=19530)
vectors_to_search = rng.random((1, dim))

print(fmt.format("Start search with retrieve several fields."))
result = milvus_client.search(collection_name, vectors_to_search, limit=3, output_fields=["id"])
# `result` holds one list of hits per query vector.
for hits in result:
    for hit in hits:
        print(f"hit: {hit}")

# Clean up so the example leaves nothing behind on the server.
milvus_client.drop_collection(collection_name)
21 changes: 16 additions & 5 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ def __init__(
self._nq = res.num_queries
all_topks = res.topks
self.recalls = res.recalls
self._pk_name = res.primary_field_name or "id"

self.cost = int(status.extra_info["report_value"] if status and status.extra_info else "0")

Expand All @@ -500,7 +501,14 @@ def __init__(
start, end = nq_thres, nq_thres + topk
nq_th_fields = self.get_fields_by_range(start, end, fields_data)
data.append(
Hits(topk, all_pks[start:end], all_scores[start:end], nq_th_fields, output_fields)
Hits(
topk,
all_pks[start:end],
all_scores[start:end],
nq_th_fields,
output_fields,
self._pk_name,
)
)
nq_thres += topk
self._session_ts = session_ts
Expand Down Expand Up @@ -673,6 +681,7 @@ def __init__(
distances: List[float],
fields: Dict[str, Tuple[List[Any], schema_pb2.FieldData]],
output_fields: List[str],
pk_name: str,
):
"""
Args:
Expand All @@ -681,6 +690,7 @@ def __init__(
"""
self.ids = pks
self.distances = distances
self._pk_name = pk_name

all_fields = list(fields.keys())
dynamic_fields = list(set(output_fields) - set(all_fields))
Expand Down Expand Up @@ -719,7 +729,7 @@ def __init__(
# sparse float vector and other fields
curr_field[fname] = data[i]

hits.append(Hit(pks[i], distances[i], curr_field))
hits.append(Hit(pks[i], distances[i], curr_field, self._pk_name))

super().__init__(hits)

Expand All @@ -739,10 +749,11 @@ class Hit:
distance: float
fields: Dict[str, Any]

def __init__(self, pk: Union[int, str], distance: float, fields: Dict[str, Any]):
def __init__(self, pk: Union[int, str], distance: float, fields: Dict[str, Any], pk_name: str):
self.id = pk
self.distance = distance
self.fields = fields
self._pk_name = pk_name

def __getattr__(self, item: str):
if item not in self.fields:
Expand All @@ -765,13 +776,13 @@ def get(self, field_name: str) -> Any:
return self.fields.get(field_name)

def __str__(self) -> str:
return f"id: {self.id}, distance: {self.distance}, entity: {self.fields}"
return f"{self._pk_name}: {self.id}, distance: {self.distance}, entity: {self.fields}"

__repr__ = __str__

def to_dict(self):
return {
"id": self.id,
self._pk_name: self.id,
"distance": self.distance,
"entity": self.fields,
}
Expand Down
Loading

0 comments on commit c07e656

Please sign in to comment.