Skip to content

Commit

Permalink
Merge pull request #386 from Navanit-git/main
Browse files Browse the repository at this point in the history
vllm support added
  • Loading branch information
zainhoda authored Apr 26, 2024
2 parents fcb69d6 + 4ad42ba commit f2cec1f
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 1 deletion.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ gemini = ["google-generativeai"]
marqo = ["marqo"]
zhipuai = ["zhipuai"]
qdrant = ["qdrant-client"]
vllm = ["vllm"]
1 change: 1 addition & 0 deletions src/vanna/vllm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .vllm import Vllm
76 changes: 76 additions & 0 deletions src/vanna/vllm/vllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import re

import requests

from ..base import VannaBase


class Vllm(VannaBase):
def __init__(self, config=None):
if config is None or "vllm_host" not in config:
self.host = "http://localhost:8000"
else:
self.host = config["vllm_host"]

if config is None or "model" not in config:
raise ValueError("check the config for vllm")
else:
self.model = config["model"]

def system_message(self, message: str) -> any:
return {"role": "system", "content": message}

def user_message(self, message: str) -> any:
return {"role": "user", "content": message}

def assistant_message(self, message: str) -> any:
return {"role": "assistant", "content": message}

def extract_sql_query(self, text):
"""
Extracts the first SQL statement after the word 'select', ignoring case,
matches until the first semicolon, three backticks, or the end of the string,
and removes three backticks if they exist in the extracted string.
Args:
- text (str): The string to search within for an SQL statement.
Returns:
- str: The first SQL statement found, with three backticks removed, or an empty string if no match is found.
"""
# Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)

match = pattern.search(text)
if match:
# Remove three backticks from the matched string if they exist
return match.group(0).replace("```", "")
else:
return text

def generate_sql(self, question: str, **kwargs) -> str:
# Use the super generate_sql
sql = super().generate_sql(question, **kwargs)

# Replace "\_" with "_"
sql = sql.replace("\\_", "_")

sql = sql.replace("\\", "")

return self.extract_sql_query(sql)

def submit_prompt(self, prompt, **kwargs) -> str:
url = f"{self.host}/v1/chat/completions"
data = {
"model": self.model,
"stream": False,
"messages": prompt,
}

response = requests.post(url, json=data)

response_dict = response.json()

self.log(response.text)

return response_dict['choices'][0]['message']['content']
2 changes: 1 addition & 1 deletion tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def test_regular_imports():
from vanna.ZhipuAI.ZhipuAI_Chat import ZhipuAI_Chat
from vanna.ZhipuAI.ZhipuAI_embeddings import ZhipuAI_Embeddings


def test_shortcut_imports():
from vanna.anthropic import Anthropic_Chat
from vanna.base import VannaBase
Expand All @@ -25,4 +24,5 @@ def test_shortcut_imports():
from vanna.ollama import Ollama
from vanna.openai import OpenAI_Chat, OpenAI_Embeddings
from vanna.vannadb import VannaDB_VectorStore
from vanna.vllm import Vllm
from vanna.ZhipuAI import ZhipuAI_Chat, ZhipuAI_Embeddings

0 comments on commit f2cec1f

Please sign in to comment.