-
Notifications
You must be signed in to change notification settings - Fork 5
/
agent.py
66 lines (58 loc) · 1.98 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from langchain.agents import AgentType, Tool, initialize_agent
from .chains import (
build_conversational_retrieval_chain,
build_csv_chain,
build_sql_database_chain,
)
from .config import AppConfig, setup_logging
# Module-level logger returned by the project's setup_logging() helper
# (imported from .config above).
# NOTE(review): no function in this module references `logger`; confirm
# whether it is still needed or kept only for setup_logging()'s side effects.
logger = setup_logging()
def parse_agent_output(agent_resp: dict) -> str:
    """Return the reply text stored under ``agent_resp["output"]``.

    The ``"output"`` entry is either the reply itself or a dict. The dict
    case presumably comes from a ``return_direct`` tool that handed back a
    chain's full result, e.g. the MLOps tool built with ``__call__``. For a
    dict, the reply is read from its ``"answer"`` key. Any other value is
    returned unchanged.

    Args:
        agent_resp: Response mapping produced by the agent; must contain an
            ``"output"`` key.

    Returns:
        The answer text.

    Raises:
        KeyError: If ``"output"`` is missing, or if it is a dict without an
            ``"answer"`` key.
    """
    output = agent_resp["output"]
    return output["answer"] if isinstance(output, dict) else output
def build_agent(config: AppConfig, *, verbose: bool = True):
    """Build a conversational ReAct agent that routes questions to three tools.

    The agent chooses between:

    * ``MLOps``: a conversational retrieval chain built from ``config``;
    * ``Palmer penguins``: a SQL chain over the SQLite database
      ``<repo_dir>/data/sqlite/palmer_penguins.db``;
    * ``Iris``: a CSV chain over ``<repo_dir>/data/csv/iris.csv``.

    Every tool is registered with ``return_direct=True``. Whatever the chosen
    tool returns therefore becomes the agent's final output, without a further
    LLM step.

    Args:
        config: Application configuration. Supplies ``repo_dir`` (root of the
            bundled data files) and ``get_llm()`` (the LLM driving the agent).
        verbose: Forwarded to ``initialize_agent``. Defaults to ``True``, the
            value previously hard-coded here.

    Returns:
        The agent executor created by ``langchain.agents.initialize_agent``.

    NOTE(review): the CONVERSATIONAL_REACT_DESCRIPTION agent's prompt takes a
    ``chat_history`` input, and no memory is attached here. Callers presumably
    pass ``chat_history`` themselves; confirm.
    """
    conversational_retrieval_chain = build_conversational_retrieval_chain(config=config)
    penguin_sql_database_chain = build_sql_database_chain(
        config=config,
        db_uri=f"sqlite:///{config.repo_dir}/data/sqlite/palmer_penguins.db",
    )
    iris_csv_chain = build_csv_chain(
        config=config, csv_path=f"{config.repo_dir}/data/csv/iris.csv"
    )

    # Descriptions are kept on one line. They are substituted verbatim into
    # the agent prompt next to the tool name. The previous triple-quoted
    # literals started with a newline and carried source indentation, which
    # split each "> name: description" entry across lines.
    tools = [
        Tool(
            name="MLOps",
            # __call__ (rather than .run) presumably yields the chain's whole
            # result dict; parse_agent_output reads its "answer" key.
            func=conversational_retrieval_chain.__call__,
            description=(
                "Useful for when you need to answer questions about mlops, "
                "mlrun, iguazio, machine learning, data science, or other "
                "related topics."
            ),
            return_direct=True,
        ),
        Tool(
            name="Palmer penguins",
            func=penguin_sql_database_chain.run,
            description=(
                "Useful for when you need to answer questions about Adelie, "
                "Gentoo, or Chinstrap penguins using a SQL database."
            ),
            return_direct=True,
        ),
        Tool(
            name="Iris",
            func=iris_csv_chain.run,
            description=(
                "Useful for when you need to answer questions about iris "
                "flowers. Includes petal sizes in cm per species."
            ),
            return_direct=True,
        ),
    ]
    return initialize_agent(
        tools=tools,
        llm=config.get_llm(),
        agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
        verbose=verbose,
    )