Skip to content

Commit

Permalink
Merge remote-tracking branch 'refs/remotes/origin/main'
Browse files: browse the repository at this point in the history
  • Loading branch information
pippo-sci committed Jul 22, 2024
2 parents bde1d6f + 6438595 commit 54e5ee6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 30 deletions.
12 changes: 6 additions & 6 deletions api/setup/load_drilldowns_to_db.py
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,14 @@
import json
import pandas as pd
import requests
import urllib.parse

from config import POSTGRES_ENGINE, SCHEMA_DRILLDOWNS, TESSERACT_API, TABLES_PATH
from utils.similarity_search import embedding
from sqlalchemy import text as sql_text

embedding_model = "sfr-embedding-mistral:q8_0"
embedding_size = 4096
DRILLDOWNS_TABLE_NAME = "drilldowns_sfr"
embedding_model = "multi-qa-mpnet-base-cos-v1"
embedding_size = 768
DRILLDOWNS_TABLE_NAME = "drilldowns"

def create_table(table_name=DRILLDOWNS_TABLE_NAME, schema_name=SCHEMA_DRILLDOWNS, embedding_size=embedding_size):
query_schema = f"CREATE SCHEMA IF NOT EXISTS {schema_name}"
Expand Down Expand Up @@ -71,7 +70,8 @@ def main(include_cubes=False):
drilldown_unique_name = level.get('unique_name')
api_url = f"{TESSERACT_API}data.jsonrecords?cube={cube_name}&drilldowns={level['unique_name'] if drilldown_unique_name is not None else drilldown_name}&measures={measure}"
load_data_to_db(api_url, measure, cube_name, drilldown_name, drilldown_unique_name)
else: pass
else:
pass

else:
create_table()
Expand All @@ -90,5 +90,5 @@ def main(include_cubes=False):
load_data_to_db(api_url, measure, cube_name, drilldown_name, drilldown_unique_name)

if __name__ == "__main__":
include_cubes = ['trade_i_baci_a_96'] # if set to False it will upload the drilldowns of all cubes in the schema.json
include_cubes = ['trade_i_baci_a_92', 'trade_i_baci_a_22'] # if set to False it will upload the drilldowns of all cubes in the schema.json
main(include_cubes)
34 changes: 10 additions & 24 deletions api/src/utils/similarity_search.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -8,40 +8,26 @@

from config import POSTGRES_ENGINE, OLLAMA_API

def get_similar_content(text, cube_name, drilldown_names, threshold=0, content_limit=1, embedding_model='multi-qa-MiniLM-L6-cos-v1', verbose=False, table_name = ""):
def get_similar_content(text, cube_name, drilldown_names, threshold=0, content_limit=1, embedding_model='multi-qa-mpnet-base-cos-v1', verbose=False):
"""
Receives a string, computes its embedding, and then looks for similar content in a database based on the given cube and drilldown levels.
Returns top match, similarity score, and others depending on the drilldown.
"""
drilldown_names_array = "{" + ",".join(map(lambda x: f'"{x}"', drilldown_names)) + "}"

if embedding_model == 'multi-qa-MiniLM-L6-cos-v1':
model = SentenceTransformer(embedding_model) # 384
embedding = model.encode([text])
query = """select drilldown_id, drilldown_name, drilldown, similarity from "match_drilldowns"('{}','{}' ,'{}','{}','{}'); """.format(embedding[0].tolist().__str__(), str(threshold), str(content_limit), str(cube_name), drilldown_names_array)
embedding_column_name = {
'multi-qa-mpnet-base-cos-v1': 'embedding' #768 dimensions
}

elif embedding_model == 'all-mpnet-base-v2' or embedding_model == 'all-MiniLM-L12-v2' or embedding_model == 'multi-qa-mpnet-base-cos-v1':
model = SentenceTransformer(embedding_model) # 384
embedding = model.encode([text])

query = """select drilldown_id, drilldown_name, drilldown, similarity from "match_drilldowns_new"('{}','{}' ,'{}','{}','{}', '{}'); """.format(embedding[0].tolist().__str__(), str(threshold), str(content_limit), str(cube_name), drilldown_names_array, str(table_name))
drilldown_names_array = "{" + ",".join(map(lambda x: f'"{x}"', drilldown_names)) + "}"

else:
url = "{}embeddings".format(OLLAMA_API)
payload = {
"model": embedding_model,
"prompt": text
}

response = requests.post(url, json = payload)
embeddings_json = json.loads(response.text)
embedding = embeddings_json['embedding']
query = """select drilldown_id, drilldown_name, drilldown, similarity from "match_drilldowns_new"('{}','{}' ,'{}','{}','{}', '{}'); """.format(embedding.__str__(), str(threshold), str(content_limit), str(cube_name), drilldown_names_array, str(table_name))
model = SentenceTransformer(embedding_model)
embedding = model.encode([text])
query = """select drilldown_id, drilldown_name, drilldown, similarity from "match_drilldowns"('{}','{}' ,'{}','{}','{}', '{}'); """.format(embedding[0].tolist().__str__(), str(threshold), str(content_limit), str(cube_name), drilldown_names_array, embedding_column_name[embedding_model])

#df = pd.read_sql(query,con=POSTGRES_ENGINE)
df = pd.read_sql_query(sql_text(query), POSTGRES_ENGINE.connect())

if verbose: print(df)
if verbose:
print(df)

drilldown_id = df.drilldown_id[0]
drilldown_name = df.drilldown_name[0]
Expand Down

0 comments on commit 54e5ee6

Please sign in to comment.