Skip to content

Commit

Permalink
Merge pull request #29 from wri/gee-improvements
Browse files Browse the repository at this point in the history
Gee improvements
  • Loading branch information
yellowcap authored Nov 29, 2024
2 parents 690e145 + 24c15b4 commit f636bda
Show file tree
Hide file tree
Showing 13 changed files with 268 additions and 22 deletions.
6 changes: 5 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,8 @@ ANTHROPIC_API_KEY=<anthropic-api-key>
LANGFUSE_SECRET_KEY=<langfuse-secret-key>
LANGFUSE_PUBLIC_KEY=<langfuse-public-key>
LANGFUSE_HOST=http://localhost:3000
TAVILY_API_KEY=<tavily-api-key>
GEE_PROJECT_ID=<gee-proj-id>
GEE_PRIVATE_KEY_ID=<gee-private-key-id>
GEE_PRIVATE_KEY=<gee-private-key>
GEE_CLIENT_EMAIL=<gee-client-email>
GEE_CLIENT_ID=<gee-client-id>
4 changes: 1 addition & 3 deletions api.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ def event_stream(query: str):

@app.post("/stream")
async def stream(query: Annotated[str, Body(embed=True)]):
return StreamingResponse(
event_stream(query), media_type="application/x-ndjson"
)
return StreamingResponse(event_stream(query), media_type="application/x-ndjson")


# Processes the query and returns the response
Expand Down
8 changes: 2 additions & 6 deletions app/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,8 @@ def generate(state):
st.write(f"**Dataset {idx+1}:** {dataset['explanation']}")
st.write(f"**URL**: {dataset['url']}")
st.write(f"**Tilelayer**: {dataset['tilelayer']}")
if st.button(
f"Show Dataset {idx+1}", key=f"dataset_{idx}"
):
st.session_state["selected_dataset"] = dataset[
"tilelayer"
]
if st.button(f"Show Dataset {idx+1}", key=f"dataset_{idx}"):
st.session_state["selected_dataset"] = dataset["tilelayer"]
except Exception as e:
st.error(f"Error processing response: {str(e)}")

Expand Down
1 change: 1 addition & 0 deletions tests/test_dist_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ def test_distalert_agent():
print(key2, val2)
pass


test_distalert_agent()
3 changes: 1 addition & 2 deletions tests/test_dist_alerts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@

from zeno.tools.dist.dist_alerts_tool import dist_alerts_tool
from zeno.tools.distalert.dist_alerts_tool import dist_alerts_tool


def test_dist_alert_tool():
Expand Down
4 changes: 2 additions & 2 deletions zeno/agents/distalert/utils/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from langgraph.prebuilt import ToolNode

from zeno.agents.maingraph.models import ModelFactory
from zeno.tools.dist.context_layer_tool import context_layer_tool
from zeno.tools.dist.dist_alerts_tool import dist_alerts_tool
from zeno.tools.distalert.context_layer_tool import context_layer_tool
from zeno.tools.distalert.dist_alerts_tool import dist_alerts_tool
from zeno.tools.location.location_tool import location_tool

_ = load_dotenv(".env")
Expand Down
8 changes: 4 additions & 4 deletions zeno/agents/docfinder/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from langgraph.prebuilt import ToolNode, tools_condition

from zeno.agents.docfinder.utils.nodes import (
agent,
generate,
grade_documents,
rewrite,
agent,
generate,
grade_documents,
rewrite,
)
from zeno.agents.docfinder.utils.state import GraphState
from zeno.tools.docretrieve.document_retrieve_tool import retriever_tool
Expand Down
6 changes: 2 additions & 4 deletions zeno/agents/maingraph/utils/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,13 @@ def maingraph(state, config: RunnableConfig):
print(f"---SLASH COMMAND {route} DETECTED---")
if route not in ["layerfinder", "docfinder", "firealert", "distalert"]:
raise ValueError(f"Slash-command {route} not valid")
# Remove slash command
question = state["question"].replace(f"/{route} ", "")
else:
model_id = config["configurable"].get("model_id", "gpt-4o-mini")
model = ModelFactory().get(model_id, json_mode=True)

response = model.invoke([sys_msg] + [HumanMessage(content=state["question"])])
response = model.invoke([sys_msg] + [HumanMessage(content=state["question"])])
route = json.loads(response.content)["route"]

if route == "layerfinder":
print("---ROUTING-TO-LAYERFINDER---")
elif route == "docfinder":
Expand Down
1 change: 1 addition & 0 deletions zeno/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from agents.maingraph.models import ModelFactory
from fastapi import Body, FastAPI, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse

from langfuse.callback import CallbackHandler

app = FastAPI()
Expand Down
28 changes: 28 additions & 0 deletions zeno/tools/distalert/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@

# Auth for GEE

To authenticate the Google Earth Engine (GEE) Python API using an API key, you can configure your environment to use a service account instead of manual authentication. This involves generating a service account in the Google Cloud Console, assigning it an API key, and linking it to your Earth Engine project.

Here’s how to do it step-by-step:

### Step 1: Enable Earth Engine API in Google Cloud Console

- Go to the Google Cloud Console.
- Create a project (or use an existing one).
- Enable the Earth Engine API for your project:
- Navigate to APIs & Services > Library.
- Search for "Earth Engine API" and enable it.

### Step 2: Create a Service Account

- Navigate to APIs & Services > Credentials in the Cloud Console.
- Click Create Credentials and select Service Account.
- Fill in the required fields, then click Create.
- Assign the "Editor" role (or other roles as required).
- Download the JSON key file for the service account.

### Step 3: Grant the Service Account Access to Earth Engine

- Go to the Earth Engine Service Account Permissions.
- Add the email of the service account (e.g., [email protected]) under Manage Permissions.
- Assign it the desired role, typically "Can Edit" or "Can View".
59 changes: 59 additions & 0 deletions zeno/tools/distalert/context_layer_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import Literal, Optional

from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
from pydantic import BaseModel, Field

from zeno.agents.maingraph.models import ModelFactory

# init_gee()


class grade(BaseModel):
"""Binary score for relevance check."""

binary_score: Literal["yes", "no"] = Field(
description="Relevance score 'yes' or 'no'"
)


prompt = PromptTemplate(
template="""You are a deciding if a context layer is required for analysing disturbance alerts. \n
Here is the user question: {question} \n
If the question asks for grouping the disturbance alerts by landcover, decide in favor of using a context layer. \n
Give a binary score 'yes' or 'no' score to indicate whether a landcover layer should be used.""",
input_variables=["question"],
)


class ContextLayerInput(BaseModel):
"""Input schema for context layer tool"""

question: str = Field(description="The question from the user")


model = ModelFactory().get("claude-3-5-sonnet-latest").with_structured_output(grade)

chain = prompt | model


@tool("context-layer-tool", args_schema=ContextLayerInput, return_direct=False)
def context_layer_tool(question: str) -> Optional[str]:
"""
Determines whether the question asks for summarizing by land cover.
"""

print("---CHECK CONTEXT LAYER TOOL---")

scored_result = chain.invoke({"question": question})

score = scored_result.binary_score

if score == "yes":
print("---DECISION: USE LANDCOVER---")
return "WRI/SBTN/naturalLands/v1/2020"
elif score == "no":
print("---DECISION: DONT USE LANDCOVER---")
return None
else:
raise ValueError(f"score was not yes or no, it was {score}")
126 changes: 126 additions & 0 deletions zeno/tools/distalert/dist_alerts_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from typing import List, Literal, Optional, Union

import ee
import googleapiclient
from dotenv import load_dotenv
from langchain_core.tools import tool
from pydantic import BaseModel, Field

from zeno.tools.distalert.gee import init_gee
from zeno.tools.location.location_matcher import LocationMatcher

# Load environment variables
load_dotenv(".env")

# Initialize gee
init_gee()

location_matcher = LocationMatcher("data/gadm41_PRT.gpkg")


class DistAlertsInput(BaseModel):
"""Input schema for dist tool"""

# class Config:
# arbitrary_types_allowed = True
features: List[str] = Field(
description="List of GADM ids are used for zonal statistics"
)
# features: FeatureCollection = Field(description="Feature collection that is used for zonal statistics")
landcover: Optional[str] = Field(
default=None, description="Landcover layer name to group zonal statistics by"
)
threshold: Optional[Literal[1, 2, 3, 4, 5, 6, 7, 8]] = Field(
default=5, description="Threshold for disturbance alert scale"
)


def print_meta(
layer: Union[ee.image.Image, ee.imagecollection.ImageCollection]
) -> None:
"""Print layer metadata"""
# Get all metadata as a dictionary
metadata = layer.getInfo()

# Print metadata
print("Image Metadata:")
for key, value in metadata.items():
print(f"{key}: {value}")


@tool(
"dist-alerts-tool",
args_schema=DistAlertsInput,
return_direct=True,
response_format="content_and_artifact",
)
def dist_alerts_tool(
features: List[str],
landcover: Optional[str] = None,
threshold: Optional[Literal[1, 2, 3, 4, 5, 6, 7, 8]] = 5,
) -> dict:
"""
Dist alerts tool
This tool quantifies vegetation disturbance alerts over an area of interest
and summarizes the alerts in statistics by landcover types.
"""
print("---DIST ALERTS TOOL---")
distalerts = ee.ImageCollection(
"projects/glad/HLSDIST/current/VEG-DIST-STATUS"
).mosaic()

gee_features = ee.FeatureCollection(
[ee.FeatureCollection(location_matcher.get_by_id(id)) for id in features]
)
gee_features = gee_features.flatten()

combo = distalerts.gte(threshold)

if landcover:
landcover_layer = ee.Image(landcover).select("classification")
combo = combo.addBands(landcover_layer)
zone_stats = combo.reduceRegions(
collection=gee_features,
reducer=ee.Reducer.count().group(groupField=1, groupName="classification"),
scale=30,
).getInfo()
zone_stats_result = {
feat["properties"]["GID_3"]: feat["properties"]["groups"]
for feat in zone_stats["features"]
}
vectorize = landcover_layer.updateMask(distalerts.gte(threshold).selfMask())
else:
zone_stats = (
distalerts.gte(threshold)
.selfMask()
.reduceRegions(
collection=gee_features,
reducer=ee.Reducer.count(),
scale=30,
)
.getInfo()
)
zone_stats_result = {
feat["properties"]["GID_3"]: [
{"classification": 1, "count": feat["properties"]["count"]}
]
for feat in zone_stats["features"]
}
vectorize = distalerts.gte(threshold).selfMask()

# Vectorize the masked classification
vectors = vectorize.reduceToVectors(
geometryType="polygon",
scale=100,
maxPixels=1e8,
geometry=gee_features,
eightConnected=True,
)

try:
vectorized = vectors.getInfo()
except googleapiclient.errors.HttpError:
vectorized = {}

return zone_stats_result, vectorized
36 changes: 36 additions & 0 deletions zeno/tools/distalert/gee.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os

import ee
from dotenv import load_dotenv
from google.oauth2.service_account import Credentials

_ = load_dotenv()


def init_gee() -> None:
"""Initialize and authenticate gee"""

service_account_info = {
"type": "service_account",
"project_id": os.environ.get("GEE_PROJECT_ID"),
"private_key_id": os.environ.get("GEE_PRIVATE_KEY_ID"),
"private_key": os.environ.get("GEE_PRIVATE_KEY"),
"client_email": os.environ.get("GEE_CLIENT_EMAIL"),
"client_id": os.environ.get("GEE_CLIENT_ID"),
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": os.environ.get("GEE_CLIENT_CERT_URL"),
"universe_domain": "googleapis.com",
}

scopes = [
"https://www.googleapis.com/auth/earthengine",
"https://www.googleapis.com/auth/cloud-platform",
]

credentials = Credentials.from_service_account_info(
service_account_info, scopes=scopes
)

ee.Initialize(credentials)

0 comments on commit f636bda

Please sign in to comment.