Skip to content

Commit

Permalink
add cuts as a string to Api class and agent's prompt to give more context to the LLM
Browse files Browse the repository at this point in the history
  • Loading branch information
alebjanes committed Jul 23, 2024
1 parent 5b2db57 commit d416fad
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 36 deletions.
25 changes: 16 additions & 9 deletions api/src/api_data_request/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
self.base_url = base_url
self.cube = table.name
self.cuts = {}
self.cuts_context = {}
self.drilldowns = set()
self.measures = set()
self.limit = None
Expand All @@ -60,10 +61,15 @@ def __init__(
if cuts:
cuts_processing(cuts, table, self)

def add_cut(self, key: str, value: str):
def add_cut(self, key: str, value: str, name: str = None):
    """
    Register a cut value under a key, together with a label for that value.

    Args:
        key: Key the value is stored under (e.g. "Year").
        value: Cut value; added to self.cuts[key] as a string.
        name: Label added to self.cuts_context[key] as a string. When omitted,
            the value itself is used as the label, so callers written against
            the two-argument form add_cut(key, value) keep working.
    """
    label = value if name is None else name
    # setdefault creates the per-key set on first use; cuts and cuts_context
    # are kept keyed in parallel.
    self.cuts.setdefault(key, set()).add(str(value))
    self.cuts_context.setdefault(key, set()).add(str(label))

def format_cuts_context(self):
    """
    Render the cut labels held in self.cuts_context as a single string.

    Each key becomes "<key> = <label>, <label>, ..." and the per-key entries
    are joined with ", ".

    Returns:
        str: The formatted description, or an empty string when
        self.cuts_context is empty.
    """
    # Labels are stored in sets, whose iteration order is arbitrary; sorting
    # makes the same cuts always produce the same text.
    return ', '.join(
        f"{key} = {', '.join(sorted(labels))}"
        for key, labels in self.cuts_context.items()
    )

def add_drilldown(self, drilldown: Union[str, List[str]]):
"""
Expand Down Expand Up @@ -191,6 +197,7 @@ def fetch_data(self) -> Tuple[Dict[str, Any], pd.DataFrame, str]:
def __str__(self):
    """Return the result of build_api() as this object's string form."""
    return self.build_api()


def cuts_processing(cuts: List[str], table: Table, api: ApiBuilder):
"""
Process cuts for the API request.
Expand Down Expand Up @@ -242,7 +249,7 @@ def cuts_processing(cuts: List[str], table: Table, api: ApiBuilder):

if year_range:
for year in range(min(year_range), max(year_range) + 1):
api.add_cut("Year", str(year))
api.add_cut("Year", str(year), str(year))

# Process other cuts
for cut in other_cuts:
Expand All @@ -257,20 +264,20 @@ def cuts_processing(cuts: List[str], table: Table, api: ApiBuilder):
var_levels = table.get_dimension_levels(var)

if var == "Year" or var == "Month" or var == "Quarter" or var == "Month and Year" or var == "Time":
api.add_cut(var, cut)
api.add_cut(var, cut, cut)
else:
drilldown_id, drilldown_name, s = get_similar_content(cut, table.name, var_levels)
drilldown_id, drilldown, s, drilldown_name = get_similar_content(cut, table.name, var_levels)

if drilldown_name != var:
if drilldown != var:
api.drilldowns.discard(var)
api.add_drilldown(drilldown_name)
api.add_drilldown(drilldown)

api.add_cut(drilldown_name, drilldown_id)
api.add_cut(drilldown, drilldown_id, drilldown_name)

for cut, values in api.cuts.items():
if len(values) > 1:
api.add_drilldown(cut)
elif "HS" in cut:
api.add_drilldown(cut)
# elif "HS" in cut:
# api.add_drilldown(cut)
else:
api.drilldowns.discard(cut)
26 changes: 14 additions & 12 deletions api/src/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time

from typing import Dict, Generator, Tuple
from typing import Dict, Tuple

from table_selection.table_selector import request_tables_to_lm_from_db
from table_selection.table import TableManager
Expand All @@ -13,12 +13,13 @@


def get_api(
natural_language_query: str,
token_tracker: Dict[str, Dict[str, int]] = None,
step: str = None,
form_json: Dict = None,
**kwargs,
) -> Tuple[str, Dict, str]:
natural_language_query: str,
token_tracker: Dict[str, Dict[str, int]] = None,
step: str = None,
form_json: Dict = None,
**kwargs
) -> Tuple[str, Dict, str]:

print("get_api")

if token_tracker is None:
Expand All @@ -41,12 +42,13 @@ def get_api(
variables, measures, cuts, token_tracker = get_api_params_from_lm(natural_language_query, kwargs["table"], token_tracker, model="gpt-4")
api = ApiBuilder(table=kwargs["table"], drilldowns=variables, measures=measures, cuts=cuts)
api_url = api.build_api()
cuts_context = api.format_cuts_context()
print("API:", api_url)
return get_api(
natural_language_query,
token_tracker,
step="fetch_data",
**{**kwargs, **{"api": api, "api_url": api_url}},
**{**kwargs, **{"api": api, "api_url": api_url, "cuts_context": cuts_context}},
)

elif step == "get_api_params_from_wrapper":
Expand Down Expand Up @@ -88,6 +90,7 @@ def get_api(
elif step == "agent_answer":
print("agent_answer")
api = kwargs["api"]
cuts_context = kwargs["cuts_context"]
variables = api.drilldowns
measures = api.measures
cuts = api.cuts
Expand Down Expand Up @@ -127,7 +130,7 @@ def get_api(
)
# insert_logs(table=table, values=values, log_type="apicall")
else:
kwargs["response"], token_tracker = agent_answer(kwargs["df"], natural_language_query, kwargs["api_url"], token_tracker)
kwargs["response"], token_tracker = agent_answer(df = kwargs["df"], natural_language_query = natural_language_query, api_url = kwargs["api_url"], context = cuts_context, token_tracker = token_tracker)
values.update(
{
"api_url": kwargs["api_url"],
Expand Down Expand Up @@ -163,7 +166,6 @@ def get_api(

if __name__ == "__main__":
get_api(
"What was the most imported product of USA in 2022?",
step="get_api_params_from_wrapper",
form_json=form_json,
"how much coffee did colombia export to the rest of the world in 2020?",
step="request_tables_to_lm_from_db",
)
33 changes: 19 additions & 14 deletions api/src/data_analysis/data_analysis.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
from langchain_experimental.agents import create_pandas_dataframe_agent
from langchain_openai import ChatOpenAI
from pandas import DataFrame
from typing import Dict, List, Tuple
from typing import Dict, Tuple

from config import OPENAI_KEY
from data_analysis.token_counter import *
from data_analysis.token_counter import TokenTrackingHandler, get_openai_token_cost_for_model

cb = TokenTrackingHandler()
ALLOW_DANGEROUS_REQUEST = True


def agent_answer(
df: DataFrame, natural_language_query: str, api_url: str, token_tracker: Dict[str, Dict[str, int]] = None, model="gpt-4-turbo"
) -> Tuple[str, Dict[str, Dict[str, int]]]:
df: DataFrame,
natural_language_query: str,
api_url: str,
context: str,
token_tracker: Dict[str, Dict[str, int]] = None,
model="gpt-4-turbo"
) -> Tuple[str, Dict[str, Dict[str, int]]]:

"""
Answer the user's question based on the provided dataframe and additional information.
Expand All @@ -29,22 +34,22 @@ def agent_answer(
- An updated token_tracker dictionary with new token usage information.
"""
prompt = f"""
You are an expert data analyst working for the Observatory of Economic Complexity. Your goal is to provide an accurate and complete answer to the following user's question using the given dataframe.
You are an expert data analyst working for the Observatory of Economic Complexity. Your goal is to provide an accurate and complete answer to the following user's question using the data available.
User's Question:
{natural_language_query}
Take into consideration the data type and formatting of the columns. If a product, service, or other variable referred to by the user appears under a different name in the dataframe, explain this politely and provide an answer using the available data.
Take into consideration the data type and formatting of the columns. If a product, service, or other variable referred to by the user appears under a different name in the data, explain this politely and provide an answer using the available information.
If you cannot answer the question with the provided data, respond with "I can't answer your question with the available data."
You can complement your answer with any content found in the Observatory of Economic Complexity. Note that this dataframe was extracted using the following API (you can see the drilldowns, measures, and cuts applied to extract the data):
{api_url}
You can complement your answer with any content found in the Observatory of Economic Complexity. Note that this data was extracted with the following filters:
{context}
Guidelines:
1. Think through the answer step by step.
2. Avoid any comments unrelated to the question.
3. Always provide the corresponding trade value, and quantity if required.
4. All quantities are in metric tons, and trade value is in USD.
Think through the answer step by step.
Avoid any comments unrelated to the question.
Always provide the corresponding trade value, and quantity if required.
All quantities are in metric tons, and trade value is in USD.
"""

simple_prompt = f"""
Expand All @@ -63,7 +68,7 @@ def agent_answer(

llm = ChatOpenAI(model_name=model, temperature=0, openai_api_key=OPENAI_KEY, callbacks=[cb])
agent = create_pandas_dataframe_agent(
llm, df, verbose=True, agent_type="openai-tools", max_iterations=3, allow_dangerous_code=ALLOW_DANGEROUS_REQUEST
llm, df, verbose=True, agent_type="openai-tools", max_iterations=3, allow_dangerous_code=ALLOW_DANGEROUS_REQUEST, number_of_head_rows=df.shape[0]
)
response = agent.invoke(prompt)

Expand Down
3 changes: 2 additions & 1 deletion api/src/utils/similarity_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ def get_similar_content(text, cube_name, drilldown_names, threshold=0, content_l
print(df)

drilldown_id = df.drilldown_id[0]
drilldown_name = df.drilldown_name[0]
drilldown = df.drilldown[0]
similarity = df.similarity[0]

return drilldown_id, drilldown, similarity
return drilldown_id, drilldown, similarity, drilldown_name


def get_similar_tables(vector, threshold=0, content_limit=1) -> List[str]:
Expand Down

0 comments on commit d416fad

Please sign in to comment.