forked from langchain-ai/opengpts
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dnd.py
136 lines (108 loc) · 4.31 KB
/
dnd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import json
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.pydantic_v1 import BaseModel, Field
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
from permchain import BaseCheckpointAdapter, Channel, Pregel
from permchain.channels import LastValue, Topic
# System prompt for the character-creation phase: the model interviews the
# single player for the four fields listed below and is told to write them
# to `notebook` once it has enough information.
character_system_msg = (
    "You are a dungeon master for a game of dungeons and dragons.\n"
    "You are interacting with the first (and only) player in the game. "
    "Your job is to collect all needed information about their character. "
    "This will be used in the quest. "
    "Feel free to ask them as many questions as needed "
    "to get to the relevant information.\n"
    "The relevant information is:\n"
    "- Character's name\n"
    "- Character's race (or species)\n"
    "- Character's class\n"
    "- Character's alignment\n"
    "Once you have gathered enough information, write that info to `notebook`."
)
class CharacterNotebook(BaseModel):
    """Notebook to write information to"""

    # NOTE: this class is passed to convert_pydantic_to_openai_function, so its
    # name, docstring and field description presumably end up in the function
    # schema shown to the model. Treat them as prompt text and edit with care.

    # Free-form summary of the player's character (name, race, class,
    # alignment, per the character-creation prompt).
    player_info: str = Field(
        description="Information about a player that you will remember over time",
    )
# Prompt for the character-creation phase: the fixed system instructions,
# followed by the running conversation supplied under the ``messages`` key.
character_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", character_system_msg),
        MessagesPlaceholder(variable_name="messages"),
    ]
)
# System prompt for the gameplay phase. ``{character}`` and ``{state}`` are
# template variables; the prompt templates built from this string fill them
# from their input.
gameplay_system_msg = (
    "You are a dungeon master for a game of dungeons and dragons.\n"
    "You are leading a quest of one person. "
    "Their character description is here:\n"
    "{character}\n"
    "A summary of the game state is here:\n"
    "{state}"
)
# Prompt for a normal gameplay turn: the gameplay system message (which needs
# ``character`` and ``state``) followed by the conversation under ``messages``.
game_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", gameplay_system_msg),
        MessagesPlaceholder(variable_name="messages"),
    ]
)
class StateNotebook(BaseModel):
    """Notebook to write information to"""

    # NOTE: this class is passed to convert_pydantic_to_openai_function, so its
    # name, docstring and field description presumably end up in the function
    # schema shown to the model. Treat them as prompt text and edit with care.

    # Free-form summary of the current game state, rewritten by the model
    # when it decides an update is needed.
    state: str = Field(
        description="Information about the current game state",
    )
# Prompt for the background state-update step. It reuses the gameplay system
# message (so it needs ``character`` and ``state``), replays the conversation
# from ``messages``, then appends a human turn asking the model whether the
# game-state notebook should be updated.
state_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", gameplay_system_msg),
        MessagesPlaceholder(variable_name="messages"),
        (
            "human",
            # Fixed typo: "neccessary" -> "necessary".
            "If any updates to the game state are necessary, "
            "please update the state notebook. If none are, just say no.",
        ),
    ]
)
def _maybe_update_state(message: AnyMessage):
    """Convert a StateNotebook function call into channel writes.

    Args:
        message: the state model's reply.

    Returns:
        ``None`` when ``message.additional_kwargs`` has no ``"function_call"``
        entry, so nothing is written. Otherwise a ``Channel.write_to`` runnable
        targeting ``"messages"`` (positional) with ``state=`` set to the decoded
        ``"state"`` argument.

    NOTE(review): with permchain's positional-channel form, the incoming
    message itself is presumably also appended to ``messages``. Confirm this
    is intended.

    Raises:
        json.JSONDecodeError / KeyError if the function-call arguments are
        malformed or lack ``"state"``. These are not handled here.
    """
    extras = message.additional_kwargs
    if "function_call" not in extras:
        return None
    call_args = json.loads(extras["function_call"]["arguments"])
    new_state = call_args["state"]
    return Channel.write_to("messages", state=new_state)
def _maybe_update_character(message: AnyMessage):
    """Finish character creation once the model calls CharacterNotebook.

    Args:
        message: the character model's reply.

    Returns:
        ``None`` when ``message.additional_kwargs`` has no ``"function_call"``
        entry. Otherwise a ``Channel.write_to`` runnable that writes a fixed
        ``AIMessage("Ready for the quest?")`` to ``messages`` and the decoded
        ``"player_info"`` argument to ``character``.

    Raises:
        json.JSONDecodeError / KeyError if the function-call arguments are
        malformed or lack ``"player_info"``. These are not handled here.
    """
    extras = message.additional_kwargs
    if "function_call" not in extras:
        return None
    notebook_args = json.loads(extras["function_call"]["arguments"])
    return Channel.write_to(
        messages=AIMessage(content="Ready for the quest?"),
        character=notebook_args["player_info"],
    )
def create_dnd_bot(llm: BaseChatModel, checkpoint: BaseCheckpointAdapter):
    """Assemble the single-player D&D game as a permchain ``Pregel`` app.

    Routing, per ``_route_to_chain`` below:

    * While the ``character`` channel is empty, a new human message goes to
      the character-creation chain.
    * Once a character is recorded, a new human message goes to the gameplay
      chain.
    * A separate chain subscribes to ``check_update``, which the gameplay
      chain writes. It asks the model whether the ``state`` summary should
      change.

    Args:
        llm: chat model used for every step. The character and state steps
            bind OpenAI-style ``functions`` onto it.
        checkpoint: checkpoint adapter passed straight to ``Pregel``.

    Returns:
        The configured ``Pregel`` instance. ``messages`` is both its input
        and its output channel.
    """
    # Character-creation model: may call the CharacterNotebook function.
    character_llm = llm.bind(
        functions=[convert_pydantic_to_openai_function(CharacterNotebook)],
    )

    # Gameplay turn: write the reply to ``messages`` and set ``check_update``
    # so the state-update chain is triggered.
    game_chain = game_prompt | llm | Channel.write_to("messages", check_update=True)

    # State-update model: may call the StateNotebook function; streaming is
    # explicitly disabled for this binding.
    state_llm = llm.bind(
        functions=[convert_pydantic_to_openai_function(StateNotebook)],
        stream=False,
    )

    # Fires on ``check_update`` and (presumably, per permchain's join) also
    # reads the current messages/character/state values as prompt input.
    state_trigger = Channel.subscribe_to(["check_update"]).join(
        ["messages", "character", "state"]
    )
    state_chain = state_trigger | state_prompt | state_llm | _maybe_update_state

    # Character creation: every model reply is appended to ``messages``. When
    # it is a CharacterNotebook call, ``_maybe_update_character`` also stores
    # the character.
    character_chain = (
        character_prompt
        | character_llm
        | Channel.write_to("messages")
        | _maybe_update_character
    )

    def _route_to_chain(_input):
        """Choose the chain for the latest message, or ``None`` to do nothing.

        Only a trailing HumanMessage triggers work. Without a recorded
        character it goes to character creation, otherwise to gameplay.
        """
        history = _input["messages"]
        if not history:
            return None
        character_known = bool(_input["character"])
        if not isinstance(history[-1], HumanMessage):
            # Our own AI/function-call writes also land in ``messages``;
            # ignore them so the app does not respond to itself.
            return None
        return game_chain if character_known else character_chain

    executor_trigger = Channel.subscribe_to(["messages"]).join(["character", "state"])
    executor = executor_trigger | _route_to_chain

    channel_spec = {
        "messages": Topic(AnyMessage, accumulate=True),
        "character": LastValue(str),
        "state": LastValue(str),
        "check_update": LastValue(bool),
    }
    return Pregel(
        chains={"executor": executor, "update_state": state_chain},
        channels=channel_spec,
        input=["messages"],
        output=["messages"],
        checkpoint=checkpoint,
    )