Skip to content

Commit

Permalink
Merge pull request #349 from scalabreseGD/distance-function-chromadb
Browse files Browse the repository at this point in the history
  • Loading branch information
zainhoda authored Apr 10, 2024
2 parents 7da87cc + 70b2aa2 commit fe2d439
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions src/vanna/chromadb/chromadb_vector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
import uuid
from typing import List

import chromadb
Expand All @@ -16,17 +15,14 @@
class ChromaDB_VectorStore(VannaBase):
def __init__(self, config=None):
VannaBase.__init__(self, config=config)
if config is None:
config = {}

if config is not None:
path = config.get("path", ".")
self.embedding_function = config.get("embedding_function", default_ef)
curr_client = config.get("client", "persistent")
self.n_results = config.get("n_results", 10)
else:
path = "."
self.embedding_function = default_ef
curr_client = "persistent" # defaults to persistent storage
self.n_results = 10 # defaults to 10 documents
path = config.get("path", ".")
self.embedding_function = config.get("embedding_function", default_ef)
curr_client = config.get("client", "persistent")
collection_metadata = config.get("collection_metadata", None)
self.n_results = config.get("n_results", 10)

if curr_client == "persistent":
self.chroma_client = chromadb.PersistentClient(
Expand All @@ -43,13 +39,19 @@ def __init__(self, config=None):
raise ValueError(f"Unsupported client was set in config: {curr_client}")

self.documentation_collection = self.chroma_client.get_or_create_collection(
name="documentation", embedding_function=self.embedding_function
name="documentation",
embedding_function=self.embedding_function,
metadata=collection_metadata,
)
self.ddl_collection = self.chroma_client.get_or_create_collection(
name="ddl", embedding_function=self.embedding_function
name="ddl",
embedding_function=self.embedding_function,
metadata=collection_metadata,
)
self.sql_collection = self.chroma_client.get_or_create_collection(
name="sql", embedding_function=self.embedding_function
name="sql",
embedding_function=self.embedding_function,
metadata=collection_metadata,
)

def generate_embedding(self, data: str, **kwargs) -> List[float]:
Expand Down

0 comments on commit fe2d439

Please sign in to comment.