
Commit
Changes to make this compatible with running purely on a GPU
rishsriv authored Dec 8, 2023
1 parent a71bf14 commit 27b5eb4
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions utils/pruning.py
@@ -10,7 +10,7 @@
 if os.getenv("TOKENIZERS_PARALLELISM") is None:
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
-encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")
+encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 nlp = spacy.load("en_core_web_sm")
 
 
@@ -23,7 +23,7 @@ def knn(
     """
     Get top most similar columns' embeddings to query using cosine similarity.
     """
-    query_emb = encoder.encode(query, convert_to_tensor=True, device="cpu").unsqueeze(0)
+    query_emb = encoder.encode(query, convert_to_tensor=True).unsqueeze(0)
     similarity_scores = F.cosine_similarity(query_emb, all_emb)
     top_results = torch.nonzero(similarity_scores > threshold).squeeze()
     # if top_results is empty, return empty tensors
@@ -95,6 +95,7 @@ def get_md_emb(
     3. Generate the metadata string using the column info so far.
     4. Get joinable columns between tables in topk_table_columns and add to final metadata string.
     """
+    column_emb = column_emb.to("cuda")
     # 1) get top k columns
     top_k_scores, top_k_indices = knn(question, column_emb, k, threshold)
     topk_table_columns = {}
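Below is a minimal, hypothetical sketch (not part of the commit) of the device behaviour this change relies on: when SentenceTransformer is constructed without an explicit device argument, it typically selects the GPU if one is available, and encode(..., convert_to_tensor=True) returns a tensor on that device; moving column_emb to "cuda" in get_md_emb then keeps both sides of the cosine-similarity computation on the same device. The example question and placeholder embeddings are illustrative only.

import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

# Without device="cpu", sentence-transformers auto-selects a device ("cuda" when available).
encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# convert_to_tensor=True keeps the query embedding on the encoder's device.
query_emb = encoder.encode("which columns are joinable?", convert_to_tensor=True).unsqueeze(0)

# Hypothetical stand-in for the precomputed column embeddings used by get_md_emb
# (all-MiniLM-L6-v2 produces 384-dimensional vectors).
column_emb = torch.randn(50, 384)
if torch.cuda.is_available():
    column_emb = column_emb.to("cuda")  # mirrors the line added in this commit

# Both tensors now live on the same device, so the similarity computation works
# whether the run is CPU-only or purely on a GPU.
similarity_scores = F.cosine_similarity(query_emb, column_emb)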
