Skip to content

Commit

Permalink
Merge pull request #277 from kun321/main
Browse files Browse the repository at this point in the history
add support for MySQL
  • Loading branch information
zainhoda authored Mar 6, 2024
2 parents 6afef33 + 6902159 commit 217b1ca
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ dependencies = [

[project.optional-dependencies]
postgres = ["psycopg2-binary", "db-dtypes"]
mysql = ["PyMySQL"]
bigquery = ["google-cloud-bigquery"]
snowflake = ["snowflake-connector-python"]
duckdb = ["duckdb"]
all = ["psycopg2-binary", "db-dtypes", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb"]
test = ["tox"]
chromadb = ["chromadb"]
openai = ["openai"]
Expand Down
85 changes: 85 additions & 0 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,91 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
self.run_sql_is_set = True
self.run_sql = run_sql_postgres


def connect_to_mysql(
self,
host: str = None,
dbname: str = None,
user: str = None,
password: str = None,
port: int = None,
):

try:
import pymysql.cursors
except ImportError:
raise DependencyError(
"You need to install required dependencies to execute this method,"
" run command: \npip install PyMySQL"
)

if not host:
host = os.getenv("HOST")

if not host:
raise ImproperlyConfigured("Please set your MySQL host")

if not dbname:
dbname = os.getenv("DATABASE")

if not dbname:
raise ImproperlyConfigured("Please set your MySQL database")

if not user:
user = os.getenv("USER")

if not user:
raise ImproperlyConfigured("Please set your MySQL user")

if not password:
password = os.getenv("PASSWORD")

if not password:
raise ImproperlyConfigured("Please set your MySQL password")

if not port:
port = os.getenv("PORT")

if not port:
raise ImproperlyConfigured("Please set your MySQL port")

conn = None

try:
conn = pymysql.connect(host=host,
user=user,
password=password,
database=dbname,
port=port,
cursorclass=pymysql.cursors.DictCursor)
except pymysql.Error as e:
raise ValidationError(e)

def run_sql_mysql(sql: str) -> Union[pd.DataFrame, None]:
if conn:
try:
cs = conn.cursor()
cs.execute(sql)
results = cs.fetchall()

# Create a pandas dataframe from the results
df = pd.DataFrame(
results, columns=[desc[0] for desc in cs.description]
)
return df

except pymysql.Error as e:
conn.rollback()
raise ValidationError(e)

except Exception as e:
conn.rollback()
raise e

self.run_sql_is_set = True
self.run_sql = run_sql_mysql


def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None):
"""
Connect to gcs using the bigquery connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql]
Expand Down

0 comments on commit 217b1ca

Please sign in to comment.