Skip to content

Commit

Permalink
Add support to embeddings from VertexAI
Browse files Browse the repository at this point in the history
  • Loading branch information
gquental committed Oct 15, 2024
1 parent c21d8bf commit bd7d55e
Showing 1 changed file with 34 additions and 15 deletions.
49 changes: 34 additions & 15 deletions src/vanna/google/bigquery_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import os
import uuid
from typing import List, Optional
from vertexai.language_models import (
TextEmbeddingInput,
TextEmbeddingModel
)

import pandas as pd
from google.cloud import bigquery
Expand All @@ -23,17 +27,15 @@ def __init__(self, config: dict, **kwargs):
or set as an environment variable, assign it.
"""
print("Configuring genai")
self.type = "GEMINI"
import google.generativeai as genai

genai.configure(api_key=config["api_key"])

self.genai = genai
else:
self.type = "VERTEX_AI"
# Authenticate using VertexAI
from vertexai.language_models import (
TextEmbeddingInput,
TextEmbeddingModel,
)

if self.config.get("project_id"):
self.project_id = self.config.get("project_id")
Expand Down Expand Up @@ -139,25 +141,42 @@ def fetch_similar_training_data(self, training_data_type: str, question: str, n_
results = self.conn.query(query).result().to_dataframe()
return results

def generate_question_embedding(self, data: str, **kwargs) -> List[float]:
result = self.genai.embed_content(
def get_embeddings(self, data: str, task: str) -> List[float]:
embeddings = None

if self.type == "VERTEX_AI":
input = [TextEmbeddingInput(data, task)]
model = TextEmbeddingModel.from_pretrained("text-embedding-004")

result = model.get_embeddings(input)

if len(result) > 0:
embeddings = result[0].values
else:
# Use Gemini Consumer API
result = self.genai.embed_content(
model="models/text-embedding-004",
content=data,
task_type="retrieval_query")
task_type=task)

if 'embedding' in result:
return result['embedding']
if 'embedding' in result:
embeddings = result['embedding']

return embeddings

def generate_question_embedding(self, data: str, **kwargs) -> List[float]:
result = self.get_embeddings(data, "RETRIEVAL_QUERY")

if result != None:
return result
else:
raise ValueError("No embeddings returned")

def generate_storage_embedding(self, data: str, **kwargs) -> List[float]:
result = self.genai.embed_content(
model="models/text-embedding-004",
content=data,
task_type="retrieval_document")
result = self.get_embeddings(data, "RETRIEVAL_DOCUMENT")

if 'embedding' in result:
return result['embedding']
if result != None:
return result
else:
raise ValueError("No embeddings returned")

Expand Down

0 comments on commit bd7d55e

Please sign in to comment.