Linted code
yellowcap committed Nov 29, 2024
1 parent 2271ed9 commit 4efc4a7
Showing 8 changed files with 110 additions and 117 deletions.
7 changes: 3 additions & 4 deletions app/basic.py
@@ -2,19 +2,18 @@
import operator
from typing import Annotated, List, TypedDict

from dotenv import load_dotenv

_ = load_dotenv()

import folium
import streamlit as st
from dotenv import load_dotenv
from langchain_anthropic import ChatAnthropic
from langchain_chroma import Chroma
from langchain_core.messages import HumanMessage
from langchain_ollama.embeddings import OllamaEmbeddings
from langgraph.graph import StateGraph
from streamlit_folium import st_folium

_ = load_dotenv()


def make_context(docs):
fmt_docs = []
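For reference, a reconstruction of the top of app/basic.py after this change, assuming the lines elided by the diff are unchanged. The lint fix moves the module-level load_dotenv() call below the imports, matching standard isort ordering:

import operator
from typing import Annotated, List, TypedDict

import folium
import streamlit as st
from dotenv import load_dotenv
from langchain_anthropic import ChatAnthropic
from langchain_chroma import Chroma
from langchain_core.messages import HumanMessage
from langchain_ollama.embeddings import OllamaEmbeddings
from langgraph.graph import StateGraph
from streamlit_folium import st_folium

# Load environment variables (API keys etc.) once, after all imports.
_ = load_dotenv()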
7 changes: 5 additions & 2 deletions nbs/02_glad-agent.ipynb
@@ -84,7 +84,8 @@
"from langgraph.graph import START, MessagesState, StateGraph\n",
"from langgraph.prebuilt import ToolNode, tools_condition\n",
"\n",
"sys_msg = SystemMessage(content=\"\"\"You are a helpful assistant tasked with answering the user queries for WRI data API.\n",
"sys_msg = SystemMessage(\n",
" content=\"\"\"You are a helpful assistant tasked with answering the user queries for WRI data API.\n",
"Use the `location-tool` to get iso, adm1 & adm2 of any region or place.\n",
"Use the `glad-weekly-alerts-tool` to get forest fire information for a particular year. Think through the solution step-by-step first and then execute.\n",
"Use the `barchart_tool` to plot the data as a barchart & return as an image.\n",
@@ -94,7 +95,8 @@
"1. Use the `location_tool` to get iso, adm1, adm2 for place `Milan` by passing `query=Milan`\n",
"2. Pass iso, adm1, adm2 along with year `2024` as args to `glad-weekly-alerts-tool` to get data about forest fire alerts.\n",
"3. Use the `barchart-tool` to create a barchart of the dataset\n",
"\"\"\")"
"\"\"\"\n",
")"
]
},
{
@@ -156,6 +158,7 @@
"from IPython.display import Image, display\n",
"import base64\n",
"\n",
"\n",
"def display_in_notebook(base64_string):\n",
" image_data = base64.b64decode(base64_string)\n",
" display(Image(image_data))"
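The system message above slots into the standard LangGraph tool-calling loop. A minimal sketch of that wiring, under the assumption that location_tool, glad_weekly_alerts_tool, and barchart_tool are defined elsewhere in the notebook and that an Anthropic model is used (the model id here is illustrative):

from langchain_anthropic import ChatAnthropic
from langgraph.graph import START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition

tools = [location_tool, glad_weekly_alerts_tool, barchart_tool]  # assumed tool objects
llm_with_tools = ChatAnthropic(model="claude-3-5-sonnet-20241022").bind_tools(tools)

def assistant(state: MessagesState):
    # Prepend sys_msg so every turn keeps the tool-routing instructions.
    return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)  # tool call -> "tools", else END
builder.add_edge("tools", "assistant")
graph = builder.compile()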
8 changes: 7 additions & 1 deletion nbs/03_wri-datasets.ipynb
@@ -111,7 +111,13 @@
"metadata": {},
"outputs": [],
"source": [
"ds_list = [\"gadm__tcl__adm2_summary\", \"gadm__viirs__iso_weekly_alerts\", \"umd_tree_cover_gain\", \"wri_tropical_tree_cover\", \"gadm__burned_areas__adm2_weekly_alerts\"]"
"ds_list = [\n",
" \"gadm__tcl__adm2_summary\",\n",
" \"gadm__viirs__iso_weekly_alerts\",\n",
" \"umd_tree_cover_gain\",\n",
" \"wri_tropical_tree_cover\",\n",
" \"gadm__burned_areas__adm2_weekly_alerts\",\n",
"]"
]
},
{
51 changes: 21 additions & 30 deletions nbs/04_rag-datasets.ipynb
@@ -77,28 +77,28 @@
"def format_dataset_metadata(dataset):\n",
" \"\"\"\n",
" Formats dataset metadata into a readable string.\n",
" \n",
"\n",
" Args:\n",
" dataset (dict): Dictionary containing dataset information with metadata\n",
" \n",
"\n",
" Returns:\n",
" str: Formatted metadata string or None if required fields are missing\n",
" \"\"\"\n",
" try:\n",
" metadata = dataset.get(\"metadata\")\n",
" if not metadata or not metadata.get(\"overview\"):\n",
" return None\n",
" \n",
"\n",
" # Define the fields to include and their labels\n",
" fields = [\n",
" (\"title\", \"Title\"),\n",
" (\"overview\", \"Overview\"),\n",
" (\"cautions\", \"Caution\"),\n",
" (\"function\", \"Function\"),\n",
" (\"geographic_coverage\", \"Geographic Coverage\"),\n",
" (\"tags\", \"Tags\")\n",
" (\"tags\", \"Tags\"),\n",
" ]\n",
" \n",
"\n",
" # Build the content string\n",
" content_parts = []\n",
" for field_name, label in fields:\n",
@@ -108,17 +108,18 @@
" if field_name == \"tags\" and isinstance(value, list):\n",
" value = \", \".join(value)\n",
" content_parts.append(f\"{label}: {value}\")\n",
" \n",
"\n",
" return \"\\n\".join(content_parts)\n",
" \n",
"\n",
" except Exception as e:\n",
" print(f\"Error processing dataset metadata: {e}\")\n",
" return None\n",
"\n",
"\n",
"def save_datasets_to_csv(datasets, output_file):\n",
" \"\"\"\n",
" Saves dataset information to a CSV file using pandas.\n",
" \n",
"\n",
" Args:\n",
" datasets (dict): Dictionary containing dataset information\n",
" output_file (str): Name of the output CSV file\n",
@@ -127,28 +128,25 @@
" # Create lists to store data\n",
" dataset_ids = []\n",
" formatted_contents = []\n",
" \n",
"\n",
" # Process each dataset\n",
" for dataset in datasets[\"data\"]:\n",
" dataset_id = dataset.get(\"dataset\")\n",
" formatted_content = format_dataset_metadata(dataset)\n",
" \n",
"\n",
" if dataset_id and formatted_content:\n",
" dataset_ids.append(dataset_id)\n",
" formatted_contents.append(formatted_content)\n",
" \n",
"\n",
" # Create DataFrame\n",
" df = pd.DataFrame({\n",
" 'dataset': dataset_ids,\n",
" 'content': formatted_contents\n",
" })\n",
" \n",
" df = pd.DataFrame({\"dataset\": dataset_ids, \"content\": formatted_contents})\n",
"\n",
" # Save to CSV\n",
" df.to_csv(output_file, index=False, encoding='utf-8')\n",
" df.to_csv(output_file, index=False, encoding=\"utf-8\")\n",
" print(f\"Successfully saved to {output_file}\")\n",
" \n",
"\n",
" return df # Return DataFrame for potential further analysis\n",
" \n",
"\n",
" except Exception as e:\n",
" print(f\"Error saving CSV file: {e}\")\n",
" return None"
@@ -222,8 +220,8 @@
"metadata": {},
"outputs": [],
"source": [
"texts = df['content'].tolist()\n",
"metadatas = [{'dataset': dataset} for dataset in df['dataset'].tolist()]\n",
"texts = df[\"content\"].tolist()\n",
"metadatas = [{\"dataset\": dataset} for dataset in df[\"dataset\"].tolist()]\n",
"ids = [f\"doc_{i}\" for i in range(len(texts))]"
]
},
@@ -236,11 +234,7 @@
"source": [
"%%time\n",
"vectorstore = Chroma.from_texts(\n",
" texts=texts,\n",
" embedding=embedder,\n",
" metadatas=metadatas,\n",
" ids=ids,\n",
" persist_directory=db\n",
" texts=texts, embedding=embedder, metadatas=metadatas, ids=ids, persist_directory=db\n",
")"
]
},
@@ -251,10 +245,7 @@
"metadata": {},
"outputs": [],
"source": [
"db = Chroma(\n",
" persist_directory=\"../data/chroma_db\", \n",
" embedding_function=embedder\n",
")"
"db = Chroma(persist_directory=\"../data/chroma_db\", embedding_function=embedder)"
]
},
{
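A short usage sketch for the two helpers above. The fetch endpoint and output path are assumptions, since the notebook cell that obtains `datasets` is outside the shown hunks; per the loop in save_datasets_to_csv, the API response is expected to carry a top-level "data" list:

import requests

# Fetch dataset metadata from the WRI data API (endpoint path assumed).
datasets = requests.get("https://data-api.globalforestwatch.org/datasets").json()

# Writes a two-column CSV (dataset, content) and returns the DataFrame, or None on error.
df = save_datasets_to_csv(datasets, "../data/wri_datasets.csv")
if df is not None:
    print(df.head())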
57 changes: 24 additions & 33 deletions nbs/05_router-agent.ipynb
@@ -8,6 +8,7 @@
"outputs": [],
"source": [
"from dotenv import load_dotenv\n",
"\n",
"_ = load_dotenv(\"../.env\")"
]
},
@@ -19,6 +20,7 @@
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.append(\"..\")"
]
},
@@ -66,10 +68,7 @@
"outputs": [],
"source": [
"embedder = OllamaEmbeddings(model=\"nomic-embed-text\")\n",
"db = Chroma(\n",
" persist_directory=\"../data/chroma_db\", \n",
" embedding_function=embedder\n",
")"
"db = Chroma(persist_directory=\"../data/chroma_db\", embedding_function=embedder)"
]
},
{
@@ -97,12 +96,13 @@
"For specific question on forest fires use the tool call.\n",
"Return JSON with single key, route, that is 'vectorstore' or 'glad-tool' depending on the question.\"\"\"\n",
"\n",
"queries = [\"I am interested in biodiversity conservation in Argentina\", \n",
" \"I would like to explore helping with forest loss in Amazon\",\n",
" \"show datasets related to mangrooves\",\n",
" \"find forest fires in milan for the year 2022\",\n",
" \"show stats on forest fires over Ihorombe for 2021\"\n",
" ]"
"queries = [\n",
" \"I am interested in biodiversity conservation in Argentina\",\n",
" \"I would like to explore helping with forest loss in Amazon\",\n",
" \"show datasets related to mangrooves\",\n",
" \"find forest fires in milan for the year 2022\",\n",
" \"show stats on forest fires over Ihorombe for 2021\",\n",
"]"
]
},
{
Expand All @@ -115,12 +115,7 @@
"# tests\n",
"for query in queries:\n",
" response = llm_json_mode.invoke(\n",
" [SystemMessage(content=router_instructions)]\n",
" + [\n",
" HumanMessage(\n",
" content=query\n",
" )\n",
" ]\n",
" [SystemMessage(content=router_instructions)] + [HumanMessage(content=query)]\n",
" )\n",
" response = json.loads(response.content)\n",
" print(query, \" ---> \", response[\"route\"])"
@@ -189,7 +184,9 @@
"def make_context(docs):\n",
" fmt_docs = []\n",
" for doc in docs:\n",
" url = f\"https://data-api.globalforestwatch.org/dataset/{doc.metadata['dataset']}\"\n",
" url = (\n",
" f\"https://data-api.globalforestwatch.org/dataset/{doc.metadata['dataset']}\"\n",
" )\n",
" content = \"URL: \" + url + \"\\n\" + doc.page_content\n",
" fmt_docs.append(content)\n",
" return \"\\n\\n\".join(fmt_docs)"
@@ -310,6 +307,7 @@
" documents = retriever.invoke(question)\n",
" return {\"documents\": documents}\n",
"\n",
"\n",
"def generate(state):\n",
" print(\"---GENERATE---\")\n",
" question = state[\"question\"]\n",
@@ -322,20 +320,24 @@
" generation = llm.invoke([HumanMessage(content=rag_prompt_fmt)])\n",
" return {\"generation\": generation, \"loop_step\": loop_step + 1}\n",
"\n",
"\n",
"def assistant(state):\n",
" sys_msg = SystemMessage(content=\"\"\"You are a helpful assistant tasked with answering the user queries for WRI data API.\n",
" sys_msg = SystemMessage(\n",
" content=\"\"\"You are a helpful assistant tasked with answering the user queries for WRI data API.\n",
" Use the `location-tool` to get iso, adm1 & adm2 of any region or place.\n",
" Use the `glad-weekly-alerts-tool` to get forest fire information for a particular year. Think through the solution step-by-step first and then execute.\n",
" \n",
" For eg: If the query is \"Find forest fires in Milan for the year 2024\"\n",
" Steps\n",
" 1. Use the `location_tool` to get iso, adm1, adm2 for place `Milan` by passing `query=Milan`\n",
" 2. Pass iso, adm1, adm2 along with year `2024` as args to `glad-weekly-alerts-tool` to get information about forest fire alerts.\n",
" \"\"\")\n",
" \"\"\"\n",
" )\n",
" if not state[\"messages\"]:\n",
" state[\"messages\"] = [HumanMessage(state[\"question\"])]\n",
" return {\"messages\": [llm_with_tools.invoke([sys_msg] + state[\"messages\"])]}\n",
"\n",
"\n",
"tool_node = ToolNode(tools)"
]
},
@@ -376,11 +378,7 @@
" print(\"---ROUTER---\")\n",
" response = llm_json_mode.invoke(\n",
" [SystemMessage(content=router_instructions)]\n",
" + [\n",
" HumanMessage(\n",
" content=state[\"question\"]\n",
" )\n",
" ]\n",
" + [HumanMessage(content=state[\"question\"])]\n",
" )\n",
" route = json.loads(response.content)[\"route\"]\n",
" if route == \"vectorstore\":\n",
@@ -439,18 +437,11 @@
"wf.add_node(\"tools\", tool_node)\n",
"\n",
"wf.set_conditional_entry_point(\n",
" router,\n",
" {\n",
" \"retrieve\": \"retrieve\",\n",
" \"assistant\": \"assistant\"\n",
" }\n",
" router, {\"retrieve\": \"retrieve\", \"assistant\": \"assistant\"}\n",
")\n",
"wf.add_edge(\"retrieve\", \"generate\")\n",
"wf.add_edge(\"generate\", END)\n",
"wf.add_conditional_edges(\n",
" \"assistant\",\n",
" tools_condition\n",
")\n",
"wf.add_conditional_edges(\"assistant\", tools_condition)\n",
"wf.add_edge(\"tools\", \"assistant\")\n",
"wf.add_edge(\"assistant\", END)\n",
"\n",
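Finally, a hedged sketch of exercising the compiled router graph. The compile call and the initial state are assumptions; the nodes above read "question" and "messages" from state, and the assistant node seeds "messages" from "question" when it is empty:

graph = wf.compile()  # assumed; the compile step is outside the shown hunks

# A dataset discovery question should take the retrieve -> generate branch.
out = graph.invoke({"question": "show datasets related to mangroves", "messages": []})
print(out["generation"].content)

# A forest-fire question should route to the tool-calling assistant instead.
out = graph.invoke({"question": "find forest fires in milan for the year 2022", "messages": []})
print(out["messages"][-1].content)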