Skip to content

Commit

Permalink
Merge pull request #5 from Dnouv/fix_task_exec
Browse files Browse the repository at this point in the history
Merge changes from Deb's fork
  • Loading branch information
Dnouv authored Jun 3, 2024
2 parents abf872e + 301a8d4 commit 5beceb3
Show file tree
Hide file tree
Showing 21 changed files with 283 additions and 123 deletions.
14 changes: 12 additions & 2 deletions .github/workflows/main-release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ on:
push:
branches:
- main
- develop

jobs:
build-and-push:
Expand All @@ -32,13 +33,22 @@ jobs:
- name: Expose GH Runtime
uses: crazy-max/ghaction-github-runtime@v3

- name: set lower case owner name
run: |
echo "OWNER_LC=${OWNER,,}" >>${GITHUB_ENV}
env:
OWNER: "${{ github.repository_owner }}"

- name: set lower case owner name
run: |
echo "REPO_LC=$(echo ${{ github.repository }} | awk 'BEGIN{FS=OFS="/"}{print tolower($1) "/" $2}')" >>${GITHUB_ENV}
- name: Build and Push Docker Images
run: |
make build_and_push_images
env:
REGISTRY: "ghcr.io"
ORG: ${{ github.repository_owner }}
REPO: ${{ github.event.repository.name }}
ORG: ${{ env.OWNER_LC }}
REPO: ${{ env.REPO_LC }}
GITHUB_WORKFLOW: ${{ github.workflow }}

build-tauri:
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tag-release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:

- name: Build and Push Docker Images
run: |
TAG=${GITHUB_REF#refs/tags/} make build_and_push_images
TAG=${GITHUB_REF#refs/heads/} make build_and_push_images
env:
REGISTRY: ghcr.io
ORG: ${{ github.repository_owner }}
Expand Down Expand Up @@ -69,7 +69,7 @@ jobs:
- name: get release version
id: get_release_version
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
run: echo "TAG=${GITHUB_REF#refs/heads/}" >> $GITHUB_ENV

- name: get release id
id: get_release_id
Expand Down
1 change: 1 addition & 0 deletions core/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .config import *
84 changes: 84 additions & 0 deletions core/config/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from os import getenv

def get_mongo_database_name() -> str:
    """Return the MongoDB database name.

    Reads ``MONGODB_DATABASE`` and falls back to ``"rubra_db"`` when unset.
    """
    return getenv("MONGODB_DATABASE", "rubra_db")


def get_mongo_url() -> str:
    """Build the MongoDB connection URL from environment variables.

    ``MONGODB_URL`` is returned verbatim when it is set to a non-empty value.
    Otherwise the URL is assembled from ``MONGODB_HOST`` (default
    ``localhost``), ``MONGODB_PORT`` (default ``27017``), the database name
    from :func:`get_mongo_database_name`, and optional credentials read from
    ``MONGODB_USER``/``MONGODB_USERNAME`` and
    ``MONGODB_PASS``/``MONGODB_PASSWORD``.

    Credentials are only used when both a user and a password are present;
    a lone user or a lone password is reported on stdout and ignored.

    Returns:
        A ``mongodb://`` connection string.
    """
    url = getenv("MONGODB_URL")
    if url:
        return url

    host = getenv("MONGODB_HOST", "localhost")
    user = getenv("MONGODB_USER", getenv("MONGODB_USERNAME", None))
    password = getenv("MONGODB_PASS", getenv("MONGODB_PASSWORD", None))
    port = getenv("MONGODB_PORT", 27017)
    database = get_mongo_database_name()

    if user and not password:
        print("MONGODB_USER set but password not found, ignoring user")

    if not user and password:
        print("MONGODB_PASSWORD set but user not found, ignoring password")

    if user and password:
        # The port placeholder previously read "${port}", which emitted a
        # literal "$" in front of the port number and produced an invalid URL.
        # NOTE(review): credentials are interpolated verbatim; values with
        # reserved characters presumably need to be percent-encoded by whoever
        # sets the environment -- confirm against the driver's URI rules.
        return f"mongodb://{user}:{password}@{host}:{port}/{database}"

    return f"mongodb://{host}:{port}/{database}"

def get_redis_url() -> str:
    """Build the Redis connection URL from environment variables.

    A non-empty ``REDIS_URL`` is returned unchanged. Otherwise the URL is
    composed from ``REDIS_HOST`` (default ``localhost``), ``REDIS_PORT``
    (default ``6379``) and ``REDIS_DATABASE`` (default ``0``). When a password
    is available via ``REDIS_PASS``/``REDIS_PASSWORD`` it is embedded together
    with the optional ``REDIS_USER``/``REDIS_USERNAME`` (empty when unset).
    A user without a password is not included in the URL.
    """
    explicit_url = getenv("REDIS_URL")
    if explicit_url:
        return explicit_url

    redis_user = getenv("REDIS_USER", getenv("REDIS_USERNAME", None))
    redis_password = getenv("REDIS_PASS", getenv("REDIS_PASSWORD", None))
    location = "{}:{}/{}".format(
        getenv("REDIS_HOST", "localhost"),
        getenv("REDIS_PORT", 6379),
        getenv("REDIS_DATABASE", 0),
    )

    if not redis_password:
        return f"redis://{location}"

    # Password-only auth is expressed as an empty user part (":password@").
    return f"redis://{redis_user or ''}:{redis_password}@{location}"

def get_litellm_url() -> str:
    """Return the base URL of the LiteLLM service.

    Uses ``LITELLM_URL`` when it is set to a non-empty value; otherwise
    builds ``http://<LITELLM_HOST>:<LITELLM_PORT>`` with the defaults
    ``localhost`` and ``8002``.
    """
    configured = getenv("LITELLM_URL")
    if not configured:
        litellm_host = getenv("LITELLM_HOST", "localhost")
        litellm_port = getenv("LITELLM_PORT", 8002)
        configured = f"http://{litellm_host}:{litellm_port}"
    return configured

def get_vector_db_url() -> str:
    """Return the base URL of the vector DB service.

    A non-empty ``VECTOR_DB_URL`` takes precedence; the fallback is
    ``http://<VECTOR_DB_HOST>:<VECTOR_DB_PORT>`` with the defaults
    ``localhost`` and ``8010``.
    """
    # ``or`` short-circuits, so the host/port lookups only happen when no
    # explicit URL was supplied.
    return getenv("VECTOR_DB_URL") or "http://{}:{}".format(
        getenv("VECTOR_DB_HOST", "localhost"),
        getenv("VECTOR_DB_PORT", 8010),
    )

def get_embedding_url() -> str:
    """Return the base URL of the embedding service.

    Uses ``EMBEDDING_URL`` when it is set to a non-empty value; otherwise
    builds ``http://<EMBEDDING_HOST>:<EMBEDDING_PORT>`` with the defaults
    ``localhost`` and ``8020``.

    Returns:
        The base URL as a string, without a trailing path.
    """
    url = getenv("EMBEDDING_URL")
    if url:
        return url

    host = getenv("EMBEDDING_HOST", "localhost")
    port = getenv("EMBEDDING_PORT", 8020)

    return f"http://{host}:{port}"

# Module-level settings, resolved once when this module is imported.
# Each value is a snapshot of the environment at import time; call the
# matching get_*() function for a fresh read.

# Name of the MongoDB database.
mongo_database = get_mongo_database_name()

# MongoDB connection URL.
mongo_url = get_mongo_url()

# Base URL of the LiteLLM service.
litellm_url = get_litellm_url()

# Base URL of the vector DB service.
vector_db_url = get_vector_db_url()

# Redis connection URL.
redis_url = get_redis_url()

# Base URL of the embedding service.
embedding_url = get_embedding_url()
2 changes: 1 addition & 1 deletion core/local_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def simple_qa(query: str, context: str) -> str:
temperature=0.1,
messages=messages,
stream=False,
response_format="web",
response_format={"type": "text"}, # mlc doesn't supports string "web"
)
return response.choices[0].message.content

Expand Down
Empty file added core/tasks/__init__.py
Empty file.
10 changes: 10 additions & 0 deletions core/tasks/celery_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import redis
from celery import Celery
from typing import cast
import core.config as configs


# Shared Redis client built from the configured Redis URL.
# cast() only affects static type checking: per the original author's note,
# from_url is typed as returning None rather than Self.
redis_client = cast(redis.Redis, redis.Redis.from_url(configs.redis_url))
# Celery application named "tasks"; the same Redis URL is used as its broker.
app = Celery("tasks", broker=configs.redis_url)

# Explicitly register the 'core.tasks' package for Celery task autodiscovery.
app.autodiscover_tasks(["core.tasks"])
7 changes: 0 additions & 7 deletions core/tasks/celery_config.py

This file was deleted.

38 changes: 38 additions & 0 deletions core/tasks/is_ready.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import socket

import requests

from core.config import litellm_url, vector_db_url

from .celery_app import app

# Seconds to wait for each HTTP health probe before giving up, so that an
# unresponsive dependency cannot block the caller indefinitely.
_HEALTH_TIMEOUT_SECONDS = 5


def _check_http_health(url: str) -> None:
    """GET *url* and raise ``Exception`` unless ``response.ok`` is true."""
    response = requests.get(url, timeout=_HEALTH_TIMEOUT_SECONDS)
    if not response.ok:
        raise Exception(response.text)

    print(response)


def is_ready():
    """Verify that the services this worker depends on are reachable.

    Checks, in order:

    1. the LiteLLM readiness endpoint (``/health/readiness``),
    2. that the Celery worker named ``celery@<hostname>`` answers a ping,
    3. the vector DB health endpoint (``/healthz``).

    Each successful probe is printed to stdout.

    Raises:
        Exception: if an HTTP probe returns an error status (the message is
            the response body) or the Celery ping gets no ``ok`` reply.
        requests.RequestException: if an HTTP probe cannot connect or times
            out.
    """
    _check_http_health(f"{litellm_url}/health/readiness")

    # Ping only the worker running on this host.
    worker_name = f"celery@{socket.gethostname()}"
    pong = app.control.ping([worker_name])
    # Each reply is expected to be a mapping of worker name -> payload; the
    # first payload must carry an "ok" entry. An empty reply is a failure.
    first_payload = next(iter(pong[0].values()), {}) if pong else {}
    if first_payload.get('ok', None) is None:
        raise Exception(f"ping failed with {pong}")

    print(pong)

    _check_http_health(f"{vector_db_url}/healthz")


if __name__ == "__main__":
    is_ready()
58 changes: 28 additions & 30 deletions core/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@
import sys
from functools import partial

from typing import cast

# Third Party
from core.tools.knowledge.vector_db.milvus.operations import add_texts, milvus_connection_alias
from langchain.text_splitter import RecursiveCharacterTextSplitter
from core.tools.knowledge.file_knowledge_tool import FileKnowledgeTool
from core.tools.web_browse.web_browse_tool import WebBrowseTool

from pymilvus import connections

# Get the current working directory
current_directory = os.getcwd()

Expand All @@ -31,33 +41,29 @@
from openai import OpenAI
from pymongo import MongoClient

litellm_host = os.getenv("LITELLM_HOST", "localhost")
redis_host = os.getenv("REDIS_HOST", "localhost")
mongodb_host = os.getenv("MONGODB_HOST", "localhost")

redis_client = redis.Redis(host=redis_host, port=6379, db=0)
app = Celery("tasks", broker=f"redis://{redis_host}:6379/0")
app.config_from_object("core.tasks.celery_config")
app.autodiscover_tasks(["core.tasks"]) # Explicitly discover tasks in 'app' package
import core.config as configs

# MongoDB Configuration
MONGODB_URL = f"mongodb://{mongodb_host}:27017"
DATABASE_NAME = "rubra_db"
from .is_ready import is_ready
from .celery_app import app

# Global MongoDB client
mongo_client = None
mongo_client: MongoClient = None

redis_client = cast(redis.Redis, redis.Redis.from_url(configs.redis_url)) # annoyingly from_url returns None, not Self

@signals.worker_process_init.connect
def setup_mongo_connection(*args, **kwargs):
def ensure_connections(*args, **kwargs):
global mongo_client
mongo_client = MongoClient(f"mongodb://{mongodb_host}:27017")
mongo_client = MongoClient(configs.mongo_url)

mongo_client.admin.command('ping')

is_ready()

def create_assistant_message(
thread_id, assistant_id, run_id, content_text, role=Role7.assistant.value
):
db = mongo_client[DATABASE_NAME]
db = mongo_client[configs.mongo_database]

# Generate a unique ID for the message
message_id = f"msg_{uuid.uuid4().hex[:6]}"
Expand Down Expand Up @@ -175,10 +181,6 @@ def rubra_local_agent_chat_completion(


def form_openai_tools(tools, assistant_id: str):
# Third Party
from core.tools.knowledge.file_knowledge_tool import FileKnowledgeTool
from core.tools.web_browse.web_browse_tool import WebBrowseTool

retrieval = FileKnowledgeTool()
googlesearch = WebBrowseTool()
res_tools = []
Expand Down Expand Up @@ -215,11 +217,12 @@ def form_openai_tools(tools, assistant_id: str):
@shared_task
def execute_chat_completion(assistant_id, thread_id, redis_channel, run_id):
try:
db = mongo_client[configs.mongo_database] # OpenAI call can fail, so we need to get the db again

oai_client = OpenAI(
base_url=f"http://{litellm_host}:8002/v1/",
api_key="abc", # point to litellm server
base_url=configs.litellm_url,
api_key=os.getenv("LITELLM_MASTER_KEY"), # point to litellm server
)
db = mongo_client[DATABASE_NAME]

# Fetch assistant and thread messages synchronously
assistant = db.assistants.find_one({"id": assistant_id})
Expand Down Expand Up @@ -453,15 +456,10 @@ def execute_chat_completion(assistant_id, thread_id, redis_channel, run_id):

@app.task
def execute_asst_file_create(file_id: str, assistant_id: str):
# Standard Library
import json

# Third Party
from core.tools.knowledge.vector_db.milvus.operations import add_texts
from langchain.text_splitter import RecursiveCharacterTextSplitter

try:
db = mongo_client[DATABASE_NAME]
if mongo_client is None:
raise Exception("MongoDB client not initialized yet")
db = mongo_client[configs.mongo_database]
collection_name = assistant_id
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
parsed_text = ""
Expand Down
9 changes: 4 additions & 5 deletions core/tools/knowledge/file_knowledge_tool.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# Standard Library
import json
import os

import core.config as configs

# Third Party
import requests

VECTOR_DB_HOST = os.getenv("VECTOR_DB_HOST", "localhost")
VECTOR_DB_MATCH_URL = f"http://{VECTOR_DB_HOST}:8010/similarity_match"

vector_db_url = f"{configs.vector_db_url}/similarity_match"

class FileKnowledgeTool:
name = "FileKnowledge"
Expand Down Expand Up @@ -42,7 +41,7 @@ def file_knowledge_search_api(query: str, assistant_id: str):
}
)

response = requests.post(VECTOR_DB_MATCH_URL, headers=headers, data=data)
response = requests.post(vector_db_url, headers=headers, data=data)
res = response.json()["response"]
txt = ""
for r in res:
Expand Down
6 changes: 3 additions & 3 deletions core/tools/knowledge/vector_db/milvus/custom_embeddigs.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# Standard Library
import json
import os
from typing import List

# Third Party
import requests
from langchain.embeddings.base import Embeddings

HOST = os.getenv("EMBEDDING_HOST", "localhost")
EMBEDDING_URL = f"http://{HOST}:8020/embed_multiple"
import core.config as configs

EMBEDDING_URL = f"{configs.embedding_url}/embed_multiple"


def embed_text(texts: List[str]) -> List[List[float]]:
Expand Down
16 changes: 7 additions & 9 deletions core/tools/knowledge/vector_db/milvus/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@
from .custom_embeddigs import CustomEmbeddings
from .query_milvus import Milvus

MILVUS_HOST = os.getenv("MILVUS_HOST", "localhost")

model = {}
top_re_rank = 5
top_k_match = 10

# Milvus connection settings, read from the environment when this module is
# imported: MILVUS_HOST (default "localhost"), MILVUS_PORT (default "19530"),
# MILVUS_USER or MILVUS_USERNAME, and MILVUS_PASS or MILVUS_PASSWORD (both
# default to an empty string).
# NOTE(review): presumably create_connection_alias registers the connection
# and returns an alias that can be reused by every collection -- confirm in
# query_milvus.Milvus.
milvus_connection_alias = Milvus.create_connection_alias({
"host": os.getenv("MILVUS_HOST", "localhost"),
"port": os.getenv("MILVUS_PORT", "19530"),
"user": os.getenv("MILVUS_USER", os.getenv("MILVUS_USERNAME", "")),
"password": os.getenv("MILVUS_PASS", os.getenv("MILVUS_PASSWORD", ""))
})

class Query(BaseModel):
text: str
Expand All @@ -27,17 +31,11 @@ class Query(BaseModel):
def drop_collection(collection_name: str):
    """Load the collection named *collection_name* and drop it."""
    target = load_collection(collection_name)
    target.drop_collection()


def load_collection(collection_name: str) -> Milvus:
return Milvus(
embedding_function=CustomEmbeddings(),
collection_name=collection_name,
connection_args={
"host": MILVUS_HOST,
"port": "19530",
"user": "username",
"password": "password",
},
alias=milvus_connection_alias,
index_params={
"metric_type": "IP",
"index_type": "FLAT",
Expand Down
Loading

0 comments on commit 5beceb3

Please sign in to comment.