diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c454b524..7f3acda9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,6 +27,7 @@ jobs: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} SNOWFLAKE_ACCOUNT: ${{ secrets.SNOWFLAKE_ACCOUNT }} SNOWFLAKE_USERNAME: ${{ secrets.SNOWFLAKE_USERNAME }} SNOWFLAKE_PASSWORD: ${{ secrets.SNOWFLAKE_PASSWORD }} diff --git a/pyproject.toml b/pyproject.toml index dcc09dc8..03da8c41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,8 @@ mysql = ["PyMySQL"] bigquery = ["google-cloud-bigquery"] snowflake = ["snowflake-connector-python"] duckdb = ["duckdb"] -all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo"] +google = ["google-generativeai", "google-cloud-aiplatform"] +all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform"] test = ["tox"] chromadb = ["chromadb"] openai = ["openai"] diff --git a/src/vanna/google/__init__.py b/src/vanna/google/__init__.py new file mode 100644 index 00000000..b0592623 --- /dev/null +++ b/src/vanna/google/__init__.py @@ -0,0 +1 @@ +from .gemini_chat import GoogleGeminiChat \ No newline at end of file diff --git a/src/vanna/google/gemini_chat.py b/src/vanna/google/gemini_chat.py new file mode 100644 index 00000000..2a857f00 --- /dev/null +++ b/src/vanna/google/gemini_chat.py @@ -0,0 +1,52 @@ +import os +from ..base import VannaBase + + +class GoogleGeminiChat(VannaBase): + def __init__(self, config=None): + VannaBase.__init__(self, config=config) + + # default temperature - can be overrided using config + self.temperature = 0.7 + + if "temperature" in config: + self.temperature = config["temperature"] + + if "model_name" in config: + model_name = config["model_name"] + else: + model_name = "gemini-1.0-pro" + + self.google_api_key = None + + if "api_key" in config or os.getenv("GOOGLE_API_KEY"): + """ + If Google api_key is provided through config + or set as an environment variable, assign it. + """ + import google.generativeai as genai + + genai.configure(api_key=config["api_key"]) + self.chat_model = genai.GenerativeModel(model_name) + else: + # Authenticate using VertexAI + from vertexai.preview.generative_models import GenerativeModel + self.chat_model = GenerativeModel("gemini-pro") + + def system_message(self, message: str) -> any: + return message + + def user_message(self, message: str) -> any: + return message + + def assistant_message(self, message: str) -> any: + return message + + def submit_prompt(self, prompt, **kwargs) -> str: + response = self.chat_model.generate_content( + prompt, + generation_config={ + "temperature": self.temperature, + }, + ) + return response.text diff --git a/tests/test_vanna.py b/tests/test_vanna.py index 1d11c3f3..82b20195 100644 --- a/tests/test_vanna.py +++ b/tests/test_vanna.py @@ -1,6 +1,7 @@ import os from vanna.anthropic.anthropic_chat import Anthropic_Chat +from vanna.google import GoogleGeminiChat from vanna.mistral.mistral import Mistral from vanna.openai.openai_chat import OpenAI_Chat from vanna.remote import VannaDefault @@ -92,9 +93,22 @@ def __init__(self, config=None): def test_vn_claude(): - sql = vn_claude.generate_sql("What are the top 5 customers by sales?") + sql = vn_claude.generate_sql("What are the top 8 customers by sales?") df = vn_claude.run_sql(sql) - assert len(df) == 5 + assert len(df) == 8 + +class VannaGemini(VannaDB_VectorStore, GoogleGeminiChat): + def __init__(self, config=None): + VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config) + GoogleGeminiChat.__init__(self, config=config) + +vn_gemini = VannaGemini(config={'api_key': os.environ['GEMINI_API_KEY']}) +vn_gemini.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') + +def test_vn_gemini(): + sql = vn_gemini.generate_sql("What are the top 9 customers by sales?") + df = vn_gemini.run_sql(sql) + assert len(df) == 9 def test_training_plan(): vn_dummy = VannaDefault(model=MY_VANNA_MODEL, api_key=MY_VANNA_API_KEY)