Skip to content

Commit

Permalink
Revert changes that were pushed by accident
Browse files · Browse the repository at this point in the history
  • Loading branch information
yellowcap committed Dec 20, 2024
1 parent c677a33 commit 5aa3d40
Showing 1 changed file with 12 additions and 40 deletions.
52 changes: 12 additions & 40 deletions zeno/tools/location/location_tool.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
import os
from typing import Tuple

import duckdb
import fiona
import requests
from dotenv import load_dotenv
from langchain_chroma.vectorstores import Chroma
from langchain_core.tools import tool
from langchain_ollama import OllamaEmbeddings
from pydantic import BaseModel, Field

# Load .env BEFORE any os.environ access below. The original called
# load_dotenv() after the Chroma setup, so OLLAMA_BASE_URL could be
# missing from the environment and raise KeyError at import time.
load_dotenv()

# GADM administrative-boundary features; looked up by fid elsewhere in
# this module (presumably by the location tool — confirm against callers).
gadm = fiona.open("data/gadm_410_small.gpkg")

# Pre-built Chroma index of GADM names. The collection must already
# exist on disk (create_collection_if_not_exists=False).
vectorstore = Chroma(
    persist_directory="data/chroma_gadm",
    embedding_function=OllamaEmbeddings(
        model="nomic-embed-text", base_url=os.environ["OLLAMA_BASE_URL"]
    ),
    collection_name="gadm",
    create_collection_if_not_exists=False,
)

class LocationInput(BaseModel):
    """Input schema for the location finder tool."""

    # Free-text place name; embedded and matched against the GADM index.
    query: str = Field(
        description="Name of the location to search for. Can be a city, region, or country name."
    )


Expand All @@ -32,46 +38,12 @@ def location_tool(query: str) -> Tuple[list, list]:
Returns a list of IDs with matches at different administrative levels
Args:
query (str): Location name to search for, different parts of the string separated by commas
query (str): Location name to search for
Returns:
matches (Tuple[list, list]): GDAM feature IDs their geojson feature collections
"""
print("---LOCATION-TOOL---")
url = f"https://api.opencagedata.com/geocode/v1/json?q={query}&key={os.environ.get('OPENCAGE_API_KEY')}"
print(url)

response = requests.get(url)
lat = response.json()["results"][0]["geometry"]["lat"]
lon = response.json()["results"][0]["geometry"]["lng"]

overture = duckdb.sql(
f"""INSTALL httpfs; INSTALL spatial;
LOAD spatial; -- noqa
LOAD httpfs; -- noqa
-- Access the data on AWS in this example
SET s3_region='us-west-2';
SELECT
subtype, names, division_id, ST_AsGeoJSON(geometry)
FROM
--'/Users/tam/Desktop/overture_division_area/*.parquet'
read_parquet('s3://overturemaps-us-west-2/release/2024-12-18.0/theme=divisions/type=division_area/*', filename=true, hive_partitioning=1)
WHERE bbox.xmin < {lon}
AND bbox.xmax > {lon}
AND bbox.ymin < {lat}
AND bbox.ymax > {lat}
AND ST_Intersects(geometry, ST_Point({lon}, {lat}))
LIMIT 3;
"""
)
print(overture)






matches = vectorstore.similarity_search(query, k=3)
fids = [int(dat.metadata["fid"]) for dat in matches]
aois = [gadm[fid] for fid in fids]
Expand Down

0 comments on commit 5aa3d40

Please sign in to comment.