diff --git a/src/vanna/google/bigquery_vector.py b/src/vanna/google/bigquery_vector.py index 09cbf391..df68835d 100644 --- a/src/vanna/google/bigquery_vector.py +++ b/src/vanna/google/bigquery_vector.py @@ -2,6 +2,10 @@ import os import uuid from typing import List, Optional +from vertexai.language_models import ( + TextEmbeddingInput, + TextEmbeddingModel +) import pandas as pd from google.cloud import bigquery @@ -23,17 +27,15 @@ def __init__(self, config: dict, **kwargs): or set as an environment variable, assign it. """ print("Configuring genai") + self.type = "GEMINI" import google.generativeai as genai genai.configure(api_key=config["api_key"]) self.genai = genai else: + self.type = "VERTEX_AI" # Authenticate using VertexAI - from vertexai.language_models import ( - TextEmbeddingInput, - TextEmbeddingModel, - ) if self.config.get("project_id"): self.project_id = self.config.get("project_id") @@ -139,25 +141,42 @@ def fetch_similar_training_data(self, training_data_type: str, question: str, n_ results = self.conn.query(query).result().to_dataframe() return results - def generate_question_embedding(self, data: str, **kwargs) -> List[float]: - result = self.genai.embed_content( + def get_embeddings(self, data: str, task: str) -> List[float]: + embeddings = None + + if self.type == "VERTEX_AI": + input = [TextEmbeddingInput(data, task)] + model = TextEmbeddingModel.from_pretrained("text-embedding-004") + + result = model.get_embeddings(input) + + if len(result) > 0: + embeddings = result[0].values + else: + # Use Gemini Consumer API + result = self.genai.embed_content( model="models/text-embedding-004", content=data, - task_type="retrieval_query") + task_type=task) - if 'embedding' in result: - return result['embedding'] + if 'embedding' in result: + embeddings = result['embedding'] + + return embeddings + + def generate_question_embedding(self, data: str, **kwargs) -> List[float]: + result = self.get_embeddings(data, "RETRIEVAL_QUERY") + + if result != None: + return result else: raise ValueError("No embeddings returned") def generate_storage_embedding(self, data: str, **kwargs) -> List[float]: - result = self.genai.embed_content( - model="models/text-embedding-004", - content=data, - task_type="retrieval_document") + result = self.get_embeddings(data, "RETRIEVAL_DOCUMENT") - if 'embedding' in result: - return result['embedding'] + if result != None: + return result else: raise ValueError("No embeddings returned")