Skip to content

Commit

Permalink
Add a get_embedding_sources composite method.
Browse files: browse the repository at this point in the history.
PiperOrigin-RevId: 704357725
  • Loading branch information
sdenton4 authored and copybara-github committed Dec 9, 2024
1 parent b156d6a commit 025f7f4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
6 changes: 6 additions & 0 deletions hoplite/db/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,12 @@ def get_embeddings(
embeddings = [self.get_embedding(int(idx)) for idx in embedding_ids]
return embedding_ids, np.array(embeddings)

def get_embedding_sources(
    self, embedding_ids: np.ndarray
) -> tuple[EmbeddingSource, ...]:
  """Look up the embedding source for each of the given embedding IDs.

  Composite convenience method: it calls `get_embedding_source` once per ID,
  in the order the IDs are supplied.

  Args:
    embedding_ids: IDs to look up. Each element is converted to a Python
      `int` before being passed to `get_embedding_source`.

  Returns:
    A tuple (not an array) holding the result of `get_embedding_source` for
    each ID, in input order. Empty input yields an empty tuple.
  """
  sources = []
  for embedding_id in embedding_ids:
    # Cast so the per-ID lookup receives a plain Python int rather than
    # whatever element type iterating `embedding_ids` produces.
    sources.append(self.get_embedding_source(int(embedding_id)))
  return tuple(sources)

def random_batched_iterator(
self,
ids: np.ndarray,
Expand Down
2 changes: 2 additions & 0 deletions hoplite/db/tests/hoplite_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def test_graph_db_interface(self, db_type, thread_split):
embs = db.get_embeddings_by_source(source.dataset_name, 'fake_id', None)
self.assertEqual(embs.shape[0], 0)

sources = db.get_embedding_sources(idxes[:3])
self.assertLen(sources, 3)
db.commit()

@parameterized.product(db_type=PERSISTENT_DB_TYPES)
Expand Down

0 comments on commit 025f7f4

Please sign in to comment.